"""Event sourcing core: emit events to Redis + project to PostgreSQL."""
import json
from collections.abc import Callable
from datetime import UTC, datetime
from typing import cast
from redis.asyncio import Redis
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from veupath_chatbot.persistence.models import StreamProjection
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
# Module-level logger, named after this module via the platform logging helper.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Plan AST helpers
# ---------------------------------------------------------------------------
def _steps_to_plan(
steps: list[JSONObject],
root_step_id: str,
snapshot: JSONObject,
) -> JSONObject | None:
"""Build a recursive plan AST from a flat steps list.
Returns None if the root step cannot be found.
"""
step_map: dict[str, JSONObject] = {}
for s in steps:
sid = s.get("id")
if isinstance(sid, str):
step_map[sid] = s
def build_node(step_id: str) -> JSONObject | None:
s = step_map.get(step_id)
if not s:
return None
raw_name = s.get("searchName") or ""
kind = str(s.get("kind") or "").strip().lower()
search_name = (
raw_name
if raw_name
else (
"__combine__"
if kind == "combine" or s.get("operator")
else "__unknown__"
)
)
node: JSONObject = {
"id": step_id,
"searchName": search_name,
"parameters": s.get("parameters", {}),
}
display = s.get("displayName")
if display:
node["displayName"] = display
op = s.get("operator")
if op:
node["operator"] = op
coloc = s.get("colocationParams")
if coloc:
node["colocationParams"] = coloc
primary_id = s.get("primaryInputStepId")
if isinstance(primary_id, str):
primary = build_node(primary_id)
if primary:
node["primaryInput"] = primary
secondary_id = s.get("secondaryInputStepId")
if isinstance(secondary_id, str):
secondary = build_node(secondary_id)
if secondary:
node["secondaryInput"] = secondary
return node
root = build_node(root_step_id)
if not root:
return None
return {
"recordType": snapshot.get("recordType", "transcript"),
"root": root,
"metadata": {"name": snapshot.get("name", "")},
}
def _count_plan_nodes(plan: JSONObject) -> int:
"""Count step nodes in a plan dict by walking the tree.
The plan dict has ``{"root": {..., "primaryInput": ..., "secondaryInput": ...}}``.
Each node with a ``searchName`` key counts as a step.
"""
root = plan.get("root")
if not isinstance(root, dict):
return 0
count = 0
def visit(node: JSONObject) -> None:
nonlocal count
if node.get("searchName"):
count += 1
primary = node.get("primaryInput")
if isinstance(primary, dict):
visit(primary)
secondary = node.get("secondaryInput")
if isinstance(secondary, dict):
visit(secondary)
visit(root)
return count
def _parse_arguments(raw: object) -> JSONObject:
"""Parse tool call arguments into a dict.
Arguments arrive as JSON strings from Redis but must be dicts for
``ToolCallResponse`` validation. Already-dict values pass through.
"""
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError, ValueError:
pass
return {}
# ---------------------------------------------------------------------------
# Event emission
# ---------------------------------------------------------------------------
# Event types where the PostgreSQL projection MUST survive crashes.
# After projecting one of these, we commit immediately so that Redis
# and PostgreSQL stay consistent even if the process dies mid-stream.
_COMMIT_AFTER = frozenset(
{
"strategy_link",
"graph_snapshot",
"graph_plan",
"model_selected",
}
)
[docs]
async def emit(
    redis: Redis,
    stream_id: str,
    operation_id: str | None,
    event_type: str,
    event_data: JSONObject,
    *,
    session: AsyncSession | None = None,
) -> str:
    """Append one event to ``stream:<stream_id>`` and optionally project it.

    The Redis entry stores three byte-encoded fields:
    - ``op``: the operation ID, or empty.
    - ``type``: the event type.
    - ``data``: the payload as JSON, with non-serializable values passed
      through ``str``.

    When ``session`` is provided, the event is also projected into
    PostgreSQL, and the session is committed for types in ``_COMMIT_AFTER``.

    Returns the Redis entry ID as text (e.g. '1709234567890-0').
    """
    payload = {
        "op": (operation_id or "").encode(),
        "type": event_type.encode(),
        "data": json.dumps(event_data, default=str).encode(),
    }
    raw_id: bytes = await redis.xadd(f"stream:{stream_id}", payload)
    # The client may hand back bytes or text depending on its configuration.
    entry_id = raw_id.decode() if isinstance(raw_id, bytes) else str(raw_id)

    if not session:
        return entry_id

    await _project_event(session, stream_id, event_type, event_data, entry_id)
    if event_type in _COMMIT_AFTER:
        await session.commit()
    return entry_id
# ---------------------------------------------------------------------------
# PostgreSQL projection — per-type handlers
# ---------------------------------------------------------------------------
# Event types that update the PostgreSQL projection. High-frequency
# streaming events (assistant_delta, tool_call_*, subkani_*, etc.) are
# skipped to avoid a DB round-trip per token.
_PROJECTED_EVENT_TYPES = frozenset(
{
"user_message",
"assistant_message",
"strategy_meta",
"strategy_link",
"graph_snapshot",
"graph_plan",
"model_selected",
"graph_cleared",
}
)
def _project_strategy_meta(updates: dict[str, object], data: JSONObject) -> None:
name = data.get("name")
if isinstance(name, str) and name:
updates["name"] = name
rt = data.get("recordType")
if isinstance(rt, str) and rt:
updates["record_type"] = rt
def _project_strategy_link(updates: dict[str, object], data: JSONObject) -> None:
wdk_id = data.get("wdkStrategyId")
if isinstance(wdk_id, int):
updates["wdk_strategy_id"] = wdk_id
is_saved = data.get("isSaved")
if isinstance(is_saved, bool):
updates["is_saved"] = is_saved
def _project_graph_snapshot(updates: dict[str, object], data: JSONObject) -> None:
snapshot = data.get("graphSnapshot")
if not isinstance(snapshot, dict):
return
steps = snapshot.get("steps")
if isinstance(steps, list):
updates["steps"] = steps
updates["step_count"] = len(steps)
root = snapshot.get("rootStepId")
if isinstance(root, str):
updates["root_step_id"] = root
name = snapshot.get("name") or snapshot.get("graphName")
if isinstance(name, str) and name:
updates["name"] = name
rt = snapshot.get("recordType")
if isinstance(rt, str) and rt:
updates["record_type"] = rt
# Build the plan AST from the flat steps when we have a root step ID.
# This covers the case where sub-kanis build steps (emitting snapshots)
# but never emit a graph_plan event.
if isinstance(steps, list) and isinstance(root, str) and steps:
typed_steps: list[JSONObject] = [s for s in steps if isinstance(s, dict)]
if typed_steps:
plan = _steps_to_plan(typed_steps, root, snapshot)
if plan:
updates["plan"] = plan
def _project_graph_plan(updates: dict[str, object], data: JSONObject) -> None:
plan_val = data.get("plan")
if isinstance(plan_val, dict):
updates["plan"] = plan_val
updates["step_count"] = _count_plan_nodes(plan_val)
name = data.get("name")
if isinstance(name, str) and name:
updates["name"] = name
rt = data.get("recordType")
if isinstance(rt, str) and rt:
updates["record_type"] = rt
def _project_model_selected(updates: dict[str, object], data: JSONObject) -> None:
model_id = data.get("modelId")
if isinstance(model_id, str):
updates["model_id"] = model_id
def _project_graph_cleared(updates: dict[str, object]) -> None:
    """Reset the stream's row to an empty, unlinked graph.

    Sets the default stream name, empties the plan and steps, clears the
    root step and step count, and detaches any WDK strategy
    (``wdk_strategy_id`` -> None, ``is_saved`` -> False).
    """
    from shared_py.defaults import DEFAULT_STREAM_NAME

    updates.update(
        {
            "name": DEFAULT_STREAM_NAME,
            "plan": {},
            "steps": [],
            "root_step_id": None,
            "step_count": 0,
            "wdk_strategy_id": None,
            "is_saved": False,
        }
    )
# Dispatch table for handlers that take (updates, data). Each handler writes
# fields taken from the event payload into ``updates`` in place.
# user_message / assistant_message and graph_cleared are not listed here;
# _project_event handles those event types directly.
_PROJECTION_HANDLERS: dict[str, Callable[[dict[str, object], JSONObject], None]] = {
    "strategy_meta": _project_strategy_meta,
    "strategy_link": _project_strategy_link,
    "graph_snapshot": _project_graph_snapshot,
    "graph_plan": _project_graph_plan,
    "model_selected": _project_model_selected,
}
def _collect_projection_updates(
    event_type: str,
    event_data: JSONObject,
    entry_id: str,
) -> dict[str, object]:
    """Return the column values a projected event writes to its stream row."""
    updates: dict[str, object] = {
        "last_event_id": entry_id,
        "updated_at": datetime.now(UTC),
    }
    if event_type in ("user_message", "assistant_message"):
        # SQL-side increment, so concurrent writers do not lose counts.
        updates["message_count"] = StreamProjection.__table__.c.message_count + 1
    elif event_type == "graph_cleared":
        _project_graph_cleared(updates)
    else:
        handler = _PROJECTION_HANDLERS.get(event_type)
        if handler:
            handler(updates, event_data)
    return updates


async def _project_event(
    session: AsyncSession,
    stream_id: str,
    event_type: str,
    event_data: JSONObject,
    entry_id: str,
) -> None:
    """Update the PostgreSQL projection based on an event.

    This is the ONLY code path that writes to stream_projections. Event
    types outside ``_PROJECTED_EVENT_TYPES`` are ignored. Every projected
    event stamps ``last_event_id`` / ``updated_at`` on the stream's row, in
    addition to its type-specific columns. Changes are flushed, not
    committed; ``emit`` decides when to commit.

    When the event claims a WDK strategy ID, any other stream holding that ID
    is released first, because WDK can reuse strategy IDs (same user, same
    search). Release and claim run inside a SAVEPOINT. If a concurrent worker
    claimed the same ID first and the ``ix_proj_wdk`` unique index rejects
    the update, only the savepoint is rolled back, and the release + update
    is retried once. Earlier uncommitted projections in the session (such as
    message_count bumps) are kept. Rolling back the whole transaction, as
    before, silently discarded them.
    """
    if event_type not in _PROJECTED_EVENT_TYPES:
        return

    updates = _collect_projection_updates(event_type, event_data, entry_id)
    stmt = (
        update(StreamProjection)
        .where(StreamProjection.stream_id == stream_id)
        .values(**updates)
    )

    # Only strategy_link (an int) and graph_cleared (None) write this column.
    wdk_id = cast(int | None, updates.get("wdk_strategy_id"))
    if wdk_id is None:
        # Nothing is being claimed, so ix_proj_wdk cannot be violated.
        await session.execute(stmt)
        await session.flush()
        return

    from sqlalchemy.exc import IntegrityError

    release_stmt = (
        update(StreamProjection)
        .where(StreamProjection.wdk_strategy_id == wdk_id)
        .where(StreamProjection.stream_id != stream_id)
        .values(wdk_strategy_id=None, is_saved=False)
    )
    for attempt in range(2):
        try:
            async with session.begin_nested():
                await session.execute(release_stmt)
                await session.execute(stmt)
                await session.flush()
        except IntegrityError as exc:
            # Two workers can auto-build the same search and race for the
            # same WDK ID: the first commit wins and the loser hits
            # ix_proj_wdk. On retry, the release statement runs again and
            # (assuming READ COMMITTED isolation) now sees the winner's row.
            if attempt > 0 or "ix_proj_wdk" not in str(exc):
                raise
        else:
            return
# ---------------------------------------------------------------------------
# Stream reconstruction
# ---------------------------------------------------------------------------
def _entry_id_to_iso(entry_id: bytes | str) -> str:
"""Convert a Redis stream entry ID (e.g. '1709234567890-0') to ISO 8601."""
raw = entry_id.decode() if isinstance(entry_id, bytes) else str(entry_id)
ms_str = raw.split("-")[0]
try:
ts = datetime.fromtimestamp(int(ms_str) / 1000, tz=UTC)
except ValueError, OSError:
ts = datetime.now(UTC)
return ts.isoformat()
class _TurnAccumulator:
"""Accumulates metadata for a single assistant turn."""
__slots__ = (
"tool_calls",
"citations",
"planning_artifacts",
"reasoning",
"model_id",
"subkani_calls",
"subkani_status",
"subkani_models",
"subkani_token_usage",
)
def __init__(self) -> None:
self.tool_calls: list[JSONObject] = []
self.citations: list[JSONObject] = []
self.planning_artifacts: list[JSONObject] = []
self.reasoning: str | None = None
self.model_id: str | None = None
self.subkani_calls: dict[str, list[JSONObject]] = {}
self.subkani_status: dict[str, str] = {}
self.subkani_models: dict[str, str] = {}
self.subkani_token_usage: dict[str, JSONObject] = {}
def reset(self) -> None:
self.tool_calls.clear()
self.citations.clear()
self.planning_artifacts.clear()
self.reasoning = None
self.model_id = None
self.subkani_calls.clear()
self.subkani_status.clear()
self.subkani_models.clear()
self.subkani_token_usage.clear()
def build_assistant_message(
self,
data: JSONObject,
entry_id: bytes | str,
) -> JSONObject:
"""Build a complete assistant message with accumulated metadata."""
msg: JSONObject = {
"role": "assistant",
"content": data.get("content", ""),
"messageId": data.get("messageId"),
"timestamp": _entry_id_to_iso(entry_id),
}
if self.model_id:
msg["modelId"] = self.model_id
if self.tool_calls:
msg["toolCalls"] = list(self.tool_calls)
if self.citations:
msg["citations"] = list(self.citations)
if self.planning_artifacts:
msg["planningArtifacts"] = list(self.planning_artifacts)
if self.reasoning:
msg["reasoning"] = self.reasoning
if self.subkani_calls:
activity: JSONObject = {
"calls": {k: list(v) for k, v in self.subkani_calls.items()},
"status": dict(self.subkani_status),
}
if self.subkani_models:
activity["models"] = dict(self.subkani_models)
if self.subkani_token_usage:
activity["tokenUsage"] = dict(self.subkani_token_usage)
msg["subKaniActivity"] = activity
# Preserve any fields directly on the event data.
for key in ("citations", "planningArtifacts", "toolCalls", "reasoning"):
if key in data and key not in msg:
msg[key] = data[key]
return msg
[docs]
async def read_stream_messages(redis: Redis, stream_id: str) -> list[JSONObject]:
"""Read all user + assistant messages from a Redis stream.
Aggregates metadata from surrounding events (tool_call_start/end,
citations, planning_artifact, reasoning, subkani events) into each
assistant_message so the full conversation context survives refresh.
Used by the GET /strategies/{id} endpoint to return chat history.
"""
entries = await redis.xrange(f"stream:{stream_id}")
messages: list[JSONObject] = []
turn = _TurnAccumulator()
for entry_id, fields in entries:
event_type = fields.get(b"type", b"").decode()
try:
data = json.loads(fields[b"data"])
except json.JSONDecodeError, KeyError:
continue
match event_type:
case "message_start":
turn.reset()
case "user_message":
messages.append(
{
"role": "user",
"content": data.get("content", ""),
"messageId": data.get("messageId"),
"timestamp": _entry_id_to_iso(entry_id),
}
)
case "tool_call_start":
turn.tool_calls.append(
{
"id": data.get("id", ""),
"name": data.get("name", ""),
"arguments": _parse_arguments(data.get("arguments")),
}
)
case "tool_call_end":
call_id = data.get("id", "")
for tc in turn.tool_calls:
if tc["id"] == call_id:
tc["result"] = data.get("result")
break
case "citations":
cites = data.get("citations")
if isinstance(cites, list):
turn.citations.extend(cites)
case "planning_artifact":
artifact = data.get("planningArtifact")
if artifact:
turn.planning_artifacts.append(artifact)
case "reasoning":
r = data.get("reasoning")
if isinstance(r, str):
turn.reasoning = r
case "model_selected":
mid = data.get("modelId")
if isinstance(mid, str):
turn.model_id = mid
case "subkani_task_start":
task = data.get("task", "")
if task:
turn.subkani_status[task] = "running"
turn.subkani_calls.setdefault(task, [])
mid = data.get("modelId")
if isinstance(mid, str) and mid:
turn.subkani_models[task] = mid
case "subkani_tool_call_start":
task = data.get("task", "")
if task:
turn.subkani_calls.setdefault(task, []).append(
{
"id": data.get("id", ""),
"name": data.get("name", ""),
"arguments": _parse_arguments(data.get("arguments")),
}
)
case "subkani_tool_call_end":
task = data.get("task", "")
call_id = data.get("id", "")
for tc in turn.subkani_calls.get(task, []):
if tc["id"] == call_id:
tc["result"] = data.get("result")
break
case "subkani_task_end":
task = data.get("task", "")
if task:
turn.subkani_status[task] = data.get("status", "done")
mid = data.get("modelId")
if isinstance(mid, str) and mid:
turn.subkani_models[task] = mid
# Capture per-task token usage if present.
pt = data.get("promptTokens")
if pt is not None:
turn.subkani_token_usage[task] = {
"promptTokens": pt or 0,
"completionTokens": data.get("completionTokens", 0),
"llmCallCount": data.get("llmCallCount", 0),
"estimatedCostUsd": data.get("estimatedCostUsd", 0.0),
}
case "assistant_message":
messages.append(turn.build_assistant_message(data, entry_id))
case "message_end":
total = data.get("totalTokens", 0)
if isinstance(total, int) and total > 0:
token_usage: JSONObject = {
"promptTokens": data.get("promptTokens", 0),
"completionTokens": data.get("completionTokens", 0),
"totalTokens": total,
"cachedTokens": data.get("cachedTokens", 0),
"toolCallCount": data.get("toolCallCount", 0),
"registeredToolCount": data.get("registeredToolCount", 0),
"llmCallCount": data.get("llmCallCount", 0),
"subKaniPromptTokens": data.get("subKaniPromptTokens", 0),
"subKaniCompletionTokens": data.get(
"subKaniCompletionTokens", 0
),
"subKaniCallCount": data.get("subKaniCallCount", 0),
"estimatedCostUsd": data.get("estimatedCostUsd", 0.0),
"modelId": data.get("modelId", ""),
}
for i in range(len(messages) - 1, -1, -1):
if (
messages[i]["role"] == "user"
and "tokenUsage" not in messages[i]
):
messages[i]["tokenUsage"] = token_usage
break
for i in range(len(messages) - 1, -1, -1):
if (
messages[i]["role"] == "assistant"
and "tokenUsage" not in messages[i]
):
messages[i]["tokenUsage"] = token_usage
break
turn.reset()
return messages
[docs]
async def read_stream_thinking(redis: Redis, stream_id: str) -> JSONObject | None:
    """Report the tool calls still running in the latest, unfinished turn.

    A turn begins at a ``message_start`` entry and finishes at a
    ``message_end``. Returns None if the stream has no turn, if the latest
    turn already finished, or if that turn has no open tool calls.
    Otherwise returns ``{"toolCalls": [...]}``. The list holds the raw
    ``tool_call_start`` payloads that have no matching ``tool_call_end``
    (matched on ``id``). Tool events whose data is missing or not valid JSON
    are ignored.
    """
    entries = await redis.xrange(f"stream:{stream_id}")

    # Index of the most recent message_start, searching from the end.
    start = next(
        (
            idx
            for idx in range(len(entries) - 1, -1, -1)
            if entries[idx][1].get(b"type", b"") == b"message_start"
        ),
        None,
    )
    if start is None:
        return None

    current_turn = entries[start:]
    if any(fields.get(b"type", b"") == b"message_end" for _eid, fields in current_turn):
        return None

    pending: dict[str, JSONObject] = {}
    for _eid, fields in current_turn:
        event_type = fields.get(b"type", b"").decode()
        if event_type not in ("tool_call_start", "tool_call_end"):
            continue
        try:
            payload = json.loads(fields[b"data"])
            call_id = payload.get("id", "")
            if event_type == "tool_call_start":
                pending[call_id] = payload
            else:
                pending.pop(call_id, None)
        except (json.JSONDecodeError, KeyError):
            pass

    return {"toolCalls": list(pending.values())} if pending else None