"""Operations endpoints: subscribe via Redis Streams, discover active operations."""
import json
from collections.abc import AsyncGenerator
from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from veupath_chatbot.persistence.models import Operation, Stream
from veupath_chatbot.platform.errors import ForbiddenError, NotFoundError
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.redis import get_redis
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.transport.http.deps import CurrentUser, DBSession
from veupath_chatbot.transport.http.sse import SSE_HEADERS
logger = get_logger(__name__)
router = APIRouter(prefix="/api/v1/operations", tags=["operations"])
_EXPERIMENT_OP_TYPES = frozenset({"experiment", "batch", "benchmark"})
async def _verify_operation_access(
session: DBSession, op: Operation, user_id: CurrentUser
) -> None:
"""Verify the current user owns the stream for non-experiment operations."""
if op.type in _EXPERIMENT_OP_TYPES:
return
stream_result = await session.execute(
select(Stream.user_id).where(Stream.id == op.stream_id)
)
stream_owner = stream_result.scalar_one_or_none()
if stream_owner != user_id:
raise ForbiddenError()
# Event types that signal end of an operation.
_END_EVENT_TYPES = frozenset(
{
"message_end",
"experiment_end",
"batch_complete",
"batch_error",
"benchmark_complete",
"benchmark_error",
"seed_complete",
}
)
[docs]
@router.get("/{operation_id}/subscribe")
async def subscribe(
operation_id: str,
session: DBSession,
user_id: CurrentUser,
last_event_id: str | None = Query(
default=None,
alias="lastEventId",
description="Resume from this Redis entry ID (for reconnection).",
),
) -> StreamingResponse:
"""SSE stream backed by Redis Streams.
Catchup: replays events from `lastEventId` (or from the beginning).
Live: uses XREAD BLOCK for new events until a terminal event is seen.
"""
# Look up operation → stream mapping.
result = await session.execute(
select(Operation).where(Operation.operation_id == operation_id)
)
op = result.scalar_one_or_none()
if op is None:
raise NotFoundError(title="Operation not found")
await _verify_operation_access(session, op, user_id)
is_experiment = op.type in _EXPERIMENT_OP_TYPES
stream_key = f"op:{operation_id}" if is_experiment else f"stream:{op.stream_id}"
async def _stream() -> AsyncGenerator[str]:
redis = get_redis()
# Start position: after last_event_id, or from the beginning.
cursor = last_event_id if last_event_id else "0-0"
while True:
# XREAD with BLOCK — waits for new events, returns when available.
# Timeout of 15s triggers a keepalive comment.
entries = await redis.xread({stream_key: cursor}, count=1, block=15000)
if not entries:
# No events within timeout — send keepalive comment.
yield ":keepalive\n\n"
# Only check operation status on keepalive (no events).
# Checking after every event risks premature exit: a fast
# producer may mark the operation "completed" while unread
# events still sit in the stream.
try:
result = await session.execute(
select(Operation.status).where(
Operation.operation_id == operation_id
)
)
status = result.scalar_one_or_none()
if status and status != "active":
return
except Exception:
logger.warning(
"Failed to check operation status",
operation_id=operation_id,
exc_info=True,
)
continue
for _stream_name, events in entries:
for entry_id_bytes, fields in events:
entry_id = (
entry_id_bytes.decode()
if isinstance(entry_id_bytes, bytes)
else str(entry_id_bytes)
)
cursor = entry_id
# Filter by operation_id for shared streams (chat).
# Experiment streams are dedicated — no filtering needed.
if not is_experiment:
event_op = fields.get(b"op", b"").decode()
if event_op and event_op != operation_id:
continue
event_type = fields.get(b"type", b"progress").decode()
try:
data = json.loads(fields.get(b"data", b"{}"))
except json.JSONDecodeError:
logger.warning(
"Failed to parse event data",
operation_id=operation_id,
entry_id=entry_id,
)
data = {}
yield f"id: {entry_id}\nevent: {event_type}\ndata: {json.dumps(data)}\n\n"
if event_type in _END_EVENT_TYPES:
return
return StreamingResponse(
_stream(),
media_type="text/event-stream",
headers=SSE_HEADERS,
)
[docs]
@router.post("/{operation_id}/cancel", status_code=202)
async def cancel(
operation_id: str,
session: DBSession,
user_id: CurrentUser,
) -> JSONObject:
"""Cancel a running operation.
For chat operations this cancels the background asyncio task running
the LLM agent. The producer's CancelledError handler emits a
``message_end`` event so any connected subscribers close cleanly.
"""
from veupath_chatbot.services.chat.orchestrator import cancel_chat_operation
result = await session.execute(
select(Operation).where(Operation.operation_id == operation_id)
)
op = result.scalar_one_or_none()
if op is None:
raise NotFoundError(title="Operation not found")
await _verify_operation_access(session, op, user_id)
if op.status != "active":
return {"operationId": operation_id, "status": op.status, "cancelled": False}
cancelled = await cancel_chat_operation(operation_id)
return {"operationId": operation_id, "cancelled": cancelled}
[docs]
@router.get("/active")
async def list_active(
session: DBSession,
user_id: CurrentUser,
stream_id: str | None = Query(
default=None,
alias="streamId",
description="Filter by stream/strategy ID.",
),
type: str | None = Query(
default=None,
description="Filter by operation type (chat, experiment).",
),
) -> list[JSONObject]:
"""List active operations, optionally filtered by stream and/or type."""
stmt = (
select(Operation)
.join(Stream, Operation.stream_id == Stream.id)
.where(Operation.status == "active", Stream.user_id == user_id)
)
if stream_id:
stmt = stmt.where(Operation.stream_id == stream_id)
if type:
stmt = stmt.where(Operation.type == type)
result = await session.execute(stmt)
ops = result.scalars().all()
return [
{
"operationId": op.operation_id,
"streamId": str(op.stream_id),
"type": op.type,
"status": op.status,
"createdAt": op.created_at.isoformat() if op.created_at else None,
}
for op in ops
]