"""Repository for stream (conversation) identity + projections."""
from datetime import UTC, datetime
from typing import Any
from uuid import UUID, uuid4
from shared_py.defaults import DEFAULT_STREAM_NAME
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from veupath_chatbot.persistence.models import Operation, Stream, StreamProjection
from veupath_chatbot.platform.types import JSONObject
[docs]
class StreamRepository:
    """Data access for conversation streams and their projections.

    Every method runs against the ``AsyncSession`` supplied at construction.
    Most writes are flushed to the database, but nothing in this class
    commits; committing or rolling back is not done here.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*."""
        self.session = session

    # ── Helpers ──

    async def _deduplicate_name(
        self,
        user_id: UUID,
        site_id: str,
        name: str,
        exclude_stream_id: UUID | None = None,
    ) -> str:
        """Return a unique name for a projection within (user, site).

        If ``name`` is already taken, ``" (1)"``, ``" (2)"``, ... is appended
        and the first free candidate is returned.  Dismissed projections are
        not filtered out, so their names also count as taken.

        Args:
            user_id: Owner whose projections are checked.
            site_id: Site whose projections are checked.
            name: Desired name.
            exclude_stream_id: Projection to ignore during the check (used
                when renaming, so a projection does not collide with itself).
        """
        query = (
            select(StreamProjection.name)
            .join(Stream)
            .where(Stream.user_id == user_id, Stream.site_id == site_id)
        )
        if exclude_stream_id is not None:
            query = query.where(StreamProjection.stream_id != exclude_stream_id)
        result = await self.session.execute(query)
        # Empty / NULL names are skipped: they never count as "taken".
        existing: set[str] = {taken for taken in result.scalars() if taken}
        if name not in existing:
            return name
        suffix = 1
        while f"{name} ({suffix})" in existing:
            suffix += 1
        return f"{name} ({suffix})"

    async def _list_projections(
        self,
        user_id: UUID,
        site_id: str | None,
        limit: int,
        *,
        dismissed: bool,
    ) -> list[StreamProjection]:
        """Shared query behind the active and dismissed projection listings.

        Active projections (``dismissed_at IS NULL``) are ordered by
        ``updated_at`` descending; dismissed ones by ``dismissed_at``
        descending.  ``site_id`` narrows the result only when truthy.
        """
        if dismissed:
            state_filter = StreamProjection.dismissed_at.isnot(None)
            ordering = StreamProjection.dismissed_at.desc()
        else:
            state_filter = StreamProjection.dismissed_at.is_(None)
            ordering = StreamProjection.updated_at.desc()
        stmt = (
            select(StreamProjection)
            .join(Stream)
            .options(joinedload(StreamProjection.stream))
            .where(Stream.user_id == user_id)
            .where(state_filter)
        )
        if site_id:
            stmt = stmt.where(Stream.site_id == site_id)
        stmt = stmt.order_by(ordering).limit(limit)
        result = await self.session.execute(stmt)
        return list(result.unique().scalars().all())

    # ── Identity ──

    async def create(
        self,
        user_id: UUID,
        site_id: str,
        *,
        stream_id: UUID | None = None,
        name: str = "",
        experiment_id: str | None = None,
    ) -> Stream:
        """Create a stream together with its projection row.

        Args:
            user_id: Owner of the new stream.
            site_id: Site the stream belongs to (stored on both rows).
            stream_id: Explicit id to use; a fresh ``uuid4()`` when omitted.
            name: Desired projection name.  An empty value falls back to
                ``DEFAULT_STREAM_NAME``; the result is then de-duplicated
                within (user, site).
            experiment_id: Optional experiment to associate with the stream.

        Returns:
            The newly added ``Stream`` (flushed, not committed).
        """
        resolved_name = await self._deduplicate_name(
            user_id,
            site_id,
            name or DEFAULT_STREAM_NAME,
        )
        stream = Stream(
            id=stream_id or uuid4(),
            user_id=user_id,
            site_id=site_id,
            experiment_id=experiment_id,
        )
        self.session.add(stream)
        # Flush before adding the projection, which references stream.id
        # (presumably so the stream row is written first).
        await self.session.flush()
        projection = StreamProjection(
            stream_id=stream.id,
            name=resolved_name,
            site_id=site_id,
        )
        self.session.add(projection)
        await self.session.flush()
        return stream

    async def get_by_id(self, stream_id: UUID) -> Stream | None:
        """Return the stream with ``stream_id``, or ``None`` if absent."""
        result = await self.session.execute(
            select(Stream).where(Stream.id == stream_id)
        )
        return result.scalar_one_or_none()

    async def find_by_experiment(
        self, user_id: UUID, experiment_id: str
    ) -> Stream | None:
        """Find an existing stream for a user + experiment combination.

        ``scalar_one_or_none`` raises if more than one stream matches.
        """
        result = await self.session.execute(
            select(Stream).where(
                Stream.user_id == user_id,
                Stream.experiment_id == experiment_id,
            )
        )
        return result.scalar_one_or_none()

    async def delete(self, stream_id: UUID) -> None:
        """Issue a DELETE for the stream row with ``stream_id``."""
        await self.session.execute(delete(Stream).where(Stream.id == stream_id))

    # ── Projections ──

    async def get_projection(self, stream_id: UUID) -> StreamProjection | None:
        """Return the projection for ``stream_id`` with its stream eager-loaded."""
        result = await self.session.execute(
            select(StreamProjection)
            .options(joinedload(StreamProjection.stream))
            .where(StreamProjection.stream_id == stream_id)
        )
        return result.scalar_one_or_none()

    async def list_projections(
        self,
        user_id: UUID,
        site_id: str | None = None,
        limit: int = 50,
    ) -> list[StreamProjection]:
        """List a user's non-dismissed projections, most recently updated first.

        Args:
            user_id: Owner of the streams.
            site_id: Restrict to this site when truthy.
            limit: Maximum number of rows returned.
        """
        return await self._list_projections(
            user_id, site_id, limit, dismissed=False
        )

    async def get_by_wdk_strategy_id(
        self, user_id: UUID, wdk_strategy_id: int
    ) -> StreamProjection | None:
        """Return the user's projection linked to ``wdk_strategy_id``, if any.

        The lookup is not restricted by site.  ``scalar_one_or_none`` raises
        if more than one projection matches.
        """
        result = await self.session.execute(
            select(StreamProjection)
            .join(Stream)
            .options(joinedload(StreamProjection.stream))
            .where(
                Stream.user_id == user_id,
                StreamProjection.wdk_strategy_id == wdk_strategy_id,
            )
        )
        return result.scalar_one_or_none()

    async def update_projection(
        self,
        stream_id: UUID,
        *,
        name: str | None = None,
        record_type: str | None = None,
        wdk_strategy_id: int | None = None,
        wdk_strategy_id_set: bool = False,
        is_saved: bool | None = None,
        is_saved_set: bool = False,
        plan: JSONObject | None = None,
        step_count: int | None = None,
        result_count: int | None = None,
        result_count_set: bool = False,
        gene_set_id: str | None = None,
        gene_set_id_set: bool = False,
        gene_set_auto_imported: bool | None = None,
    ) -> None:
        """Dynamically update a StreamProjection based on provided kwargs.

        Steps and root_step_id are derived from plan at read time; only plan
        and a denormalized step_count are persisted on write.

        Plain arguments are written only when they are not ``None``.
        Arguments paired with a ``*_set`` flag are written whenever the flag
        is true, which allows the column to be set to ``None`` explicitly
        (``is_saved`` is coerced with ``bool()``, so ``None`` is stored as
        ``False``).  ``updated_at`` is always refreshed.  A new ``name`` is
        de-duplicated against the owner's other projections on the same site.
        """
        values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
        if name is not None:
            # Deduplicate rename against other projections for the same
            # user+site; the projection itself is excluded from the check.
            proj = await self.get_projection(stream_id)
            if proj and proj.stream:
                name = await self._deduplicate_name(
                    proj.stream.user_id,
                    proj.stream.site_id,
                    name,
                    exclude_stream_id=stream_id,
                )
            values["name"] = name
        if record_type is not None:
            values["record_type"] = record_type
        if wdk_strategy_id_set:
            values["wdk_strategy_id"] = wdk_strategy_id
        if is_saved_set:
            values["is_saved"] = bool(is_saved)
        if plan is not None:
            values["plan"] = plan
        if step_count is not None:
            values["step_count"] = step_count
        if result_count_set:
            values["result_count"] = result_count
        if gene_set_id_set:
            values["gene_set_id"] = gene_set_id
        if gene_set_auto_imported is not None:
            values["gene_set_auto_imported"] = gene_set_auto_imported
        await self.session.execute(
            update(StreamProjection)
            .where(StreamProjection.stream_id == stream_id)
            .values(**values)
        )
        await self.session.flush()

    async def dismiss(self, stream_id: UUID) -> None:
        """Soft-delete: mark a projection as dismissed (hidden from main list)."""
        await self.session.execute(
            update(StreamProjection)
            .where(StreamProjection.stream_id == stream_id)
            .values(dismissed_at=datetime.now(UTC))
        )
        await self.session.flush()

    async def restore(self, stream_id: UUID) -> None:
        """Un-dismiss: restore a dismissed projection and reset for fresh WDK import.

        Clears ``dismissed_at`` and resets ``plan`` to ``{}`` and
        ``message_count`` to ``0``.
        """
        # NOTE(review): step_count / result_count are left untouched here;
        # confirm that keeping them after the plan reset is intended.
        await self.session.execute(
            update(StreamProjection)
            .where(StreamProjection.stream_id == stream_id)
            .values(
                dismissed_at=None,
                plan={},
                message_count=0,
            )
        )
        await self.session.flush()

    async def list_dismissed_projections(
        self,
        user_id: UUID,
        site_id: str | None = None,
        limit: int = 50,
    ) -> list[StreamProjection]:
        """List a user's dismissed projections, most recently dismissed first.

        Args:
            user_id: Owner of the streams.
            site_id: Restrict to this site when truthy.
            limit: Maximum number of rows returned.
        """
        return await self._list_projections(
            user_id, site_id, limit, dismissed=True
        )

    async def prune_wdk_orphans(
        self,
        user_id: UUID,
        site_id: str,
        live_wdk_ids: set[int],
    ) -> int:
        """Delete streams whose projections have wdk_strategy_id not in the live set.

        Only non-dismissed projections with a non-NULL ``wdk_strategy_id``
        for the given user and site are considered.  Note that an empty
        ``live_wdk_ids`` therefore marks every such stream as an orphan.

        Returns the number of pruned streams.
        """
        # One SELECT narrows to WDK-linked, non-dismissed projections in SQL;
        # the comparison against the live set happens in Python below.
        stmt = (
            select(StreamProjection.stream_id, StreamProjection.wdk_strategy_id)
            .join(Stream)
            .where(
                Stream.user_id == user_id,
                Stream.site_id == site_id,
                StreamProjection.wdk_strategy_id.isnot(None),
                StreamProjection.dismissed_at.is_(None),
            )
        )
        result = await self.session.execute(stmt)
        orphan_ids = [
            stream_id
            for stream_id, wdk_id in result.all()
            if wdk_id not in live_wdk_ids
        ]
        if not orphan_ids:
            return 0
        # Batch delete all orphan streams; their projections are expected to
        # go with them via cascade (per the original comment; confirm the FK
        # configuration on the models).
        await self.session.execute(delete(Stream).where(Stream.id.in_(orphan_ids)))
        await self.session.flush()
        return len(orphan_ids)

    # ── Operations ──

    async def register_operation(
        self, operation_id: str, stream_id: UUID, op_type: str
    ) -> Operation:
        """Insert a new operation for ``stream_id`` with status ``"active"``."""
        op = Operation(
            operation_id=operation_id,
            stream_id=stream_id,
            type=op_type,
            status="active",
        )
        self.session.add(op)
        await self.session.flush()
        return op

    async def _set_operation_status(self, operation_id: str, status: str) -> None:
        """Set ``status`` on an operation and stamp ``completed_at`` with now (UTC).

        ``completed_at`` is written for every status passed in, not only
        ``"completed"``.
        """
        await self.session.execute(
            update(Operation)
            .where(Operation.operation_id == operation_id)
            .values(status=status, completed_at=datetime.now(UTC))
        )
        await self.session.flush()

    async def complete_operation(self, operation_id: str) -> None:
        """Mark the operation as ``"completed"``."""
        await self._set_operation_status(operation_id, "completed")

    async def fail_operation(self, operation_id: str) -> None:
        """Mark the operation as ``"failed"``."""
        await self._set_operation_status(operation_id, "failed")

    async def cancel_operation(self, operation_id: str) -> None:
        """Mark the operation as ``"cancelled"``."""
        await self._set_operation_status(operation_id, "cancelled")

    async def get_active_operations(self, stream_id: UUID) -> list[Operation]:
        """Return the stream's ``"active"`` operations, newest first."""
        result = await self.session.execute(
            select(Operation)
            .where(Operation.stream_id == stream_id, Operation.status == "active")
            .order_by(Operation.created_at.desc())
        )
        return list(result.scalars().all())

    async def list_active_operations(
        self, op_type: str | None = None
    ) -> list[Operation]:
        """Return all ``"active"`` operations, filtered by ``op_type`` when truthy."""
        stmt = select(Operation).where(Operation.status == "active")
        if op_type:
            stmt = stmt.where(Operation.type == op_type)
        result = await self.session.execute(stmt)
        return list(result.scalars().all())