Source code for veupath_chatbot.transport.http.deps
"""Dependency injection for HTTP routes."""
import asyncio
from typing import Annotated
from uuid import UUID
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from veupath_chatbot.persistence.repositories import (
ControlSetRepository,
StreamRepository,
UserRepository,
)
from veupath_chatbot.persistence.session import get_db_session
from veupath_chatbot.platform.errors import ForbiddenError, NotFoundError
from veupath_chatbot.platform.security import get_current_user
from veupath_chatbot.services.experiment.store import get_experiment_store
from veupath_chatbot.services.experiment.types import Experiment
# Type aliases for dependencies
DBSession = Annotated[AsyncSession, Depends(get_db_session)]
[docs]
async def get_user_repo(session: DBSession) -> UserRepository:
"""Get user repository."""
return UserRepository(session)
[docs]
async def get_control_set_repo(session: DBSession) -> ControlSetRepository:
"""Get control set repository."""
return ControlSetRepository(session)
[docs]
async def get_stream_repo(session: DBSession) -> StreamRepository:
"""Get stream repository."""
return StreamRepository(session)
UserRepo = Annotated[UserRepository, Depends(get_user_repo)]
ControlSetRepo = Annotated[ControlSetRepository, Depends(get_control_set_repo)]
StreamRepo = Annotated[StreamRepository, Depends(get_stream_repo)]
[docs]
async def get_current_user_with_db_row(
user_id: Annotated[UUID, Depends(get_current_user)],
user_repo: UserRepo,
) -> UUID:
"""Ensure authenticated users exist in the local DB.
We persist user IDs because many tables have a FK to `users.id`. Without this,
first-time sessions can trigger integrity errors that bubble up as 500s.
"""
await user_repo.get_or_create(user_id)
return user_id
CurrentUser = Annotated[UUID, Depends(get_current_user_with_db_row)]
[docs]
async def get_experiment_owned_by_user(
experiment_id: str,
user_id: CurrentUser,
) -> Experiment:
"""Resolve an experiment by ID and verify the current user owns it."""
store = get_experiment_store()
exp = await store.aget(experiment_id)
if not exp:
raise NotFoundError(title="Experiment not found")
if exp.user_id != str(user_id):
raise ForbiddenError(title="Not authorized to access this experiment")
return exp
ExperimentDep = Annotated[Experiment, Depends(get_experiment_owned_by_user)]
[docs]
async def get_experiments_owned_by_user(
experiment_ids: list[str], user_id: str
) -> list[Experiment]:
"""Fetch multiple experiments by ID and verify ownership (parallel)."""
store = get_experiment_store()
fetched = await asyncio.gather(*(store.aget(eid) for eid in experiment_ids))
experiments: list[Experiment] = []
for eid, exp in zip(experiment_ids, fetched, strict=True):
if not exp:
raise NotFoundError(title=f"Experiment {eid} not found")
if exp.user_id != user_id:
raise ForbiddenError(title="Not authorized to access this experiment")
experiments.append(exp)
return experiments