Source code for veupath_chatbot.persistence.models

"""SQLAlchemy ORM models."""

from datetime import datetime
from uuid import UUID, uuid4

from sqlalchemy import (
    JSON,
    Boolean,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    String,
    Text,
    func,
)
from sqlalchemy.engine import Dialect
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.types import CHAR, TypeDecorator, TypeEngine

from veupath_chatbot.platform.types import JSONArray, JSONObject


class GUID(TypeDecorator[UUID]):
    """Platform-independent GUID type.

    Uses CHAR(36) and stores UUIDs as strings in their canonical form
    (lowercase, hyphenated — what ``str(uuid)`` produces).

    Returns proper ``UUID`` objects on read so that Python-side
    comparisons (e.g. ``stream.user_id == some_uuid``) work correctly.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[str]:
        """Return the dialect-specific ``CHAR(36)`` type used for storage."""
        return dialect.type_descriptor(CHAR(36))

    def process_bind_param(
        self, value: UUID | str | None, dialect: Dialect
    ) -> str | None:
        """Convert a Python-side value into the string bound to the statement.

        ``None`` stays ``None`` and ``UUID`` objects become ``str(value)``.
        A string that parses as a UUID is rewritten to that same canonical
        form, so equivalent spellings (uppercase, braces, no hyphens,
        ``urn:uuid:`` prefix) compare equal to values written from ``UUID``
        objects. Anything else is passed through unchanged.

        NOTE(review): rows written earlier from a non-canonical string keep
        their original text and will only match the canonical spelling once
        they are rewritten — confirm no such rows exist.
        """
        if value is None:
            return None
        if isinstance(value, UUID):
            return str(value)
        if isinstance(value, str):
            try:
                return str(UUID(value))
            except ValueError:
                # Not a UUID spelling: keep the historical pass-through
                # behaviour instead of raising at bind time.
                return value
        return value

    def process_result_value(
        self, value: str | UUID | None, dialect: Dialect
    ) -> UUID | None:
        """Convert a value read from the database into a ``UUID``.

        ``None`` stays ``None``; a value that is already a ``UUID`` is
        returned as is; a string is parsed with ``UUID(value)``, which raises
        ``ValueError`` when the stored text is not a valid UUID.
        """
        if value is None:
            return None
        if isinstance(value, UUID):
            return value
        return UUID(value)
class Base(DeclarativeBase):
    """Declarative base that every ORM model in this module inherits from."""

    # Python annotation type -> SQL column type, shared by all subclasses.
    type_annotation_map = {
        JSONObject: JSON,
        JSONArray: JSON,
        UUID: GUID,
    }
class User(Base):
    """User model for tracking strategies.

    Rows live in the ``users`` table. The ``streams`` relationship is the
    counterpart of ``Stream.user``.
    """

    __tablename__ = "users"

    # Primary key, generated client-side with uuid4 when not supplied.
    id: Mapped[UUID] = mapped_column(GUID(), primary_key=True, default=uuid4)
    # Optional, unique when present.
    external_id: Mapped[str | None] = mapped_column(String(255), unique=True)
    # Filled in by the database (server-side ``now()``) on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships
    # ``Stream`` is declared further down the module, so the name is quoted:
    # SQLAlchemy resolves the string lazily, whereas a bare name would be
    # looked up while this class body runs on interpreters that evaluate
    # annotations eagerly and fail with NameError.
    # "all, delete-orphan": streams are deleted together with their user and
    # when they are removed from this collection.
    streams: Mapped[list["Stream"]] = relationship(
        back_populates="user", cascade="all, delete-orphan"
    )
class ControlSet(Base):
    """Reusable control gene set with provenance metadata.

    Rows live in the ``control_sets`` table, indexed by ``site_id`` and by
    ``user_id``.
    """

    __tablename__ = "control_sets"

    # --- identity / ownership ---
    id: Mapped[UUID] = mapped_column(
        GUID(),
        primary_key=True,
        default=uuid4,
    )
    # Set to NULL by the database when the referenced user row is deleted.
    user_id: Mapped[UUID | None] = mapped_column(
        GUID(),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )

    # --- descriptive fields ---
    name: Mapped[str] = mapped_column(String(255))
    site_id: Mapped[str] = mapped_column(String(100))
    record_type: Mapped[str] = mapped_column(String(100))

    # --- JSON payloads; each defaults to a fresh empty list ---
    positive_ids: Mapped[JSONArray] = mapped_column(JSON, default=list)
    negative_ids: Mapped[JSONArray] = mapped_column(JSON, default=list)
    source: Mapped[str | None] = mapped_column(String(50))
    tags: Mapped[JSONArray] = mapped_column(JSON, default=list)
    provenance_notes: Mapped[str | None] = mapped_column(Text)

    # --- bookkeeping ---
    version: Mapped[int] = mapped_column(Integer, default=1)
    is_public: Mapped[bool] = mapped_column(Boolean, default=False)
    # Filled in by the database (server-side ``now()``) on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    __table_args__ = (
        Index("ix_control_sets_site_id", "site_id"),
        Index("ix_control_sets_user_id", "user_id"),
    )
class ExperimentRow(Base):
    """Persisted experiment with full JSON blob.

    Rows live in the ``experiments`` table. The blob itself is the ``data``
    column; the remaining columns are scalar fields, four of which
    (``site_id``, ``user_id``, ``batch_id``, ``benchmark_id``) are indexed.
    """

    __tablename__ = "experiments"

    # String primary key; no default, so the caller supplies it.
    id: Mapped[str] = mapped_column(String(50), primary_key=True)
    site_id: Mapped[str] = mapped_column(String(100))
    # Plain 36-character string column; no foreign key is declared on it.
    user_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
    name: Mapped[str] = mapped_column(String(255), default="")
    status: Mapped[str] = mapped_column(String(20), default="pending")
    # JSON document; defaults to a fresh empty dict.
    data: Mapped[JSONObject] = mapped_column(JSON, default=dict)
    batch_id: Mapped[str | None] = mapped_column(String(50), nullable=True)
    benchmark_id: Mapped[str | None] = mapped_column(
        String(50),
        nullable=True,
    )

    # Timestamps: created_at is set by the database on insert; updated_at is
    # additionally refreshed with ``now()`` whenever the ORM emits an UPDATE.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    __table_args__ = (
        Index("ix_experiments_site_id", "site_id"),
        Index("ix_experiments_user_id", "user_id"),
        Index("ix_experiments_batch_id", "batch_id"),
        Index("ix_experiments_benchmark_id", "benchmark_id"),
    )
class Stream(Base):
    """A conversation stream — the identity of a chat conversation.

    Only identity and ownership are kept in this table; all mutable state
    is derived from events in Redis.
    """

    __tablename__ = "streams"

    id: Mapped[UUID] = mapped_column(
        GUID(),
        primary_key=True,
        default=uuid4,
    )
    # Owning user; the database deletes the stream when that user is deleted.
    user_id: Mapped[UUID] = mapped_column(
        GUID(),
        ForeignKey("users.id", ondelete="CASCADE"),
    )
    site_id: Mapped[str] = mapped_column(String(50))
    # Optional plain string; no foreign key is declared on it.
    experiment_id: Mapped[str | None] = mapped_column(
        String(50),
        nullable=True,
    )
    # Filled in by the database (server-side ``now()``) on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # Counterpart of ``User.streams``.
    user: Mapped[User] = relationship(back_populates="streams")

    # Composite indexes, both led by user_id.
    __table_args__ = (
        Index("ix_streams_user_site", "user_id", "site_id"),
        Index("ix_streams_experiment", "user_id", "experiment_id"),
    )
class StreamProjection(Base):
    """Materialized projection of a conversation stream.

    This is a CACHE for fast reads, not a source of truth: it is derived
    from events and can be rebuilt by replaying the Redis stream.
    """

    __tablename__ = "stream_projections"

    # One projection per stream: the foreign key doubles as the primary key,
    # and the row is removed by the database together with its stream.
    stream_id: Mapped[UUID] = mapped_column(
        GUID(),
        ForeignKey("streams.id", ondelete="CASCADE"),
        primary_key=True,
    )

    # --- scalar fields ---
    name: Mapped[str] = mapped_column(String(255), default="")
    record_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # Column type is taken from the ``int | None`` annotation.
    wdk_strategy_id: Mapped[int | None] = mapped_column(nullable=True)
    is_saved: Mapped[bool] = mapped_column(Boolean, default=False)
    model_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
    message_count: Mapped[int] = mapped_column(Integer, default=0)
    step_count: Mapped[int] = mapped_column(Integer, default=0)

    # --- JSON payloads; default to a fresh empty dict / list ---
    plan: Mapped[JSONObject] = mapped_column(JSON, default=dict)
    steps: Mapped[JSONArray] = mapped_column(JSON, default=list)
    root_step_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # Column type is taken from the ``int | None`` annotation.
    result_count: Mapped[int | None] = mapped_column(nullable=True)
    last_event_id: Mapped[str | None] = mapped_column(String(30), nullable=True)

    # --- timestamps ---
    # Set by the database on insert and refreshed with ``now()`` whenever the
    # ORM emits an UPDATE for the row.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    dismissed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
    site_id: Mapped[str] = mapped_column(String(50), default="")

    # Gene set auto-import association
    # The reference is set to NULL by the database when the gene set row is
    # deleted.
    gene_set_id: Mapped[str | None] = mapped_column(
        String(50),
        ForeignKey("gene_sets.id", ondelete="SET NULL"),
        nullable=True,
    )
    gene_set_auto_imported: Mapped[bool] = mapped_column(Boolean, default=False)

    # One-directional relationship to the owning stream (no back_populates).
    stream: Mapped[Stream] = relationship()

    __table_args__ = (
        # Unique over wdk_strategy_id; on PostgreSQL the index is partial and
        # covers only rows where the column is not NULL.
        Index(
            "ix_proj_wdk",
            "wdk_strategy_id",
            unique=True,
            postgresql_where="wdk_strategy_id IS NOT NULL",
        ),
    )
class GeneSetRow(Base):
    """Persisted gene set for workbench analysis.

    Rows live in the ``gene_sets`` table, indexed by ``user_id``, by
    ``site_id`` and by the ``(user_id, site_id)`` pair.
    """

    __tablename__ = "gene_sets"

    # String primary key; no default, so the caller supplies it.
    id: Mapped[str] = mapped_column(String(50), primary_key=True)
    # Plain 36-character string column; no foreign key is declared on it.
    user_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
    site_id: Mapped[str] = mapped_column(String(100))
    name: Mapped[str] = mapped_column(String(255), default="")
    # JSON array; defaults to a fresh empty list.
    gene_ids: Mapped[JSONArray] = mapped_column(JSON, default=list)
    source: Mapped[str] = mapped_column(String(20), default="paste")

    # Optional integer ids; column types come from the annotations.
    wdk_strategy_id: Mapped[int | None] = mapped_column(nullable=True)
    wdk_step_id: Mapped[int | None] = mapped_column(nullable=True)

    search_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    record_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
    parameters: Mapped[JSONObject | None] = mapped_column(JSON, nullable=True)
    # JSON array; defaults to a fresh empty list.
    parent_set_ids: Mapped[JSONArray] = mapped_column(JSON, default=list)
    operation: Mapped[str | None] = mapped_column(String(20), nullable=True)
    step_count: Mapped[int] = mapped_column(Integer, default=1)

    # Filled in by the database (server-side ``now()``) on insert.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    __table_args__ = (
        Index("ix_gene_sets_user_id", "user_id"),
        Index("ix_gene_sets_site_id", "site_id"),
        Index("ix_gene_sets_user_site", "user_id", "site_id"),
    )
class Operation(Base):
    """Tracks active and completed operations for client discovery.

    Rows live in the ``operations`` table and are indexed by
    ``(stream_id, status)``.
    """

    __tablename__ = "operations"

    # String primary key; no default, so the caller supplies it.
    operation_id: Mapped[str] = mapped_column(String(32), primary_key=True)
    # Owning stream; the database deletes the row together with that stream.
    stream_id: Mapped[UUID] = mapped_column(
        GUID(),
        ForeignKey("streams.id", ondelete="CASCADE"),
    )
    type: Mapped[str] = mapped_column(String(50))
    # New rows start out as "active".
    status: Mapped[str] = mapped_column(String(20), default="active")

    # created_at is filled in by the database on insert; completed_at is
    # nullable and has no default.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )

    # One-directional relationship to the owning stream (no back_populates).
    stream: Mapped[Stream] = relationship()

    __table_args__ = (Index("ix_ops_stream_status", "stream_id", "status"),)