# Source code for veupath_chatbot.platform.store
"""Generic write-through store: in-memory cache + fire-and-forget DB persistence.

Provides the shared save/get/delete/aget/adelete logic so that concrete
stores only need to supply their ORM model class, row conversion functions,
and custom listing methods.

Subclasses define three class-level attributes:

* ``_model`` — SQLAlchemy ORM model class (e.g. ``ExperimentRow``)
* ``_to_row`` — callable mapping ``entity -> dict[str, object]`` for upsert
* ``_from_row`` — callable mapping ``row -> entity`` to reconstruct the domain object

The base class derives persist / load / delete from those, eliminating the
boilerplate that was previously duplicated across every concrete store.
"""
from collections.abc import Callable
from typing import Any, Protocol, cast
from sqlalchemy import delete as sa_delete
from sqlalchemy.dialects.postgresql import insert as pg_insert
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.tasks import spawn
logger = get_logger(__name__)
[docs]
class Identifiable(Protocol):
    """Structural type for any entity that exposes a read-only string ``id``."""

    @property
    def id(self) -> str:
        """Return the identifier used as the entity's key."""
        ...
[docs]
class WriteThruStore[T: Identifiable]:
"""In-memory cache backed by fire-and-forget DB writes.
Subclasses must set three class-level attributes:
* ``_model`` — SQLAlchemy ORM model (must have an ``id`` column)
* ``_to_row`` — ``(entity) -> dict`` of column values for upsert
* ``_from_row`` — ``(row) -> T`` to reconstruct the entity from a DB row
Every entity must satisfy the ``Identifiable`` protocol (have ``id: str``).
"""
_model: Any = None # SQLAlchemy ORM model class — set by subclass
_to_row: Callable[[T], dict[str, object]] = cast(
"Callable[[T], dict[str, object]]", cast(object, None)
)
_from_row: Callable[..., T] = cast("Callable[..., T]", cast(object, None))
[docs]
def __init__(self) -> None:
self._cache: dict[str, T] = {}
# -- DB helpers (derived from _model / _to_row / _from_row) ----------------
async def _persist(self, entity: T) -> None:
"""Upsert an entity row into the database."""
from veupath_chatbot.persistence.session import async_session_factory
try:
vals = self._to_row(entity)
stmt = (
pg_insert(self._model)
.values(**vals)
.on_conflict_do_update(
index_elements=[self._model.id],
set_={k: v for k, v in vals.items() if k != "id"},
)
)
async with async_session_factory() as session:
await session.execute(stmt)
await session.commit()
except Exception:
logger.exception(
"Failed to persist entity to DB",
entity_type=self._model.__tablename__,
entity_id=entity.id,
)
async def _load(self, entity_id: str) -> T | None:
"""Load a single entity from the database by primary key."""
from veupath_chatbot.persistence.session import async_session_factory
async with async_session_factory() as session:
row = await session.get(self._model, entity_id)
if row is None:
return None
return self._from_row(row)
async def _delete_from_db(self, entity_id: str) -> None:
"""Delete an entity row from the database."""
from veupath_chatbot.persistence.session import async_session_factory
stmt = sa_delete(self._model).where(self._model.id == entity_id)
async with async_session_factory() as session:
await session.execute(stmt)
await session.commit()
# -- Sync interface ---------------------------------------------------
[docs]
def save(self, entity: T) -> None:
self._cache[entity.id] = entity
spawn(self._persist(entity), name=f"persist-{entity.id}")
[docs]
def get(self, entity_id: str) -> T | None:
return self._cache.get(entity_id)
[docs]
def delete(self, entity_id: str) -> bool:
removed = self._cache.pop(entity_id, None) is not None
if removed:
spawn(self._delete_from_db(entity_id), name=f"delete-{entity_id}")
return removed
# -- Async interface --------------------------------------------------
[docs]
async def aget(self, entity_id: str) -> T | None:
entity = self._cache.get(entity_id)
if entity is not None:
return entity
entity = await self._load(entity_id)
if entity is not None:
self._cache[entity_id] = entity
return entity
[docs]
async def adelete(self, entity_id: str) -> bool:
removed = entity_id in self._cache
self._cache.pop(entity_id, None)
await self._delete_from_db(entity_id)
return removed