# Source code for veupath_chatbot.transport.http.routers.gene_sets

"""Gene set management endpoints.

Thin transport layer: parse HTTP request, call service, return HTTP response.
All business logic lives in ``services.gene_sets.operations``.
"""

from typing import Literal, cast, get_args

from fastapi import APIRouter, Query, Request

from veupath_chatbot.platform.errors import (
    InternalError,
    NotFoundError,
    ValidationError,
)
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.security import limiter
from veupath_chatbot.platform.types import JSONObject, JSONValue
from veupath_chatbot.services.experiment.types import to_json
from veupath_chatbot.services.gene_sets.confidence import (
    compute_gene_confidence,
)
from veupath_chatbot.services.gene_sets.ensemble import (
    EnsembleScore,
    compute_ensemble_scores,
)
from veupath_chatbot.services.gene_sets.operations import GeneSetService
from veupath_chatbot.services.gene_sets.reverse_search import (
    GeneSetCandidate,
    rank_gene_sets_by_recall,
)
from veupath_chatbot.services.gene_sets.store import get_gene_set_store
from veupath_chatbot.services.gene_sets.types import GeneSet
from veupath_chatbot.transport.http.deps import CurrentUser
from veupath_chatbot.transport.http.schemas.gene_sets import (
    CreateGeneSetRequest,
    EnsembleScoringRequest,
    GeneConfidenceRequest,
    GeneConfidenceScoreResponse,
    GeneSetEnrichRequest,
    GeneSetResponse,
    ReverseSearchRequest,
    ReverseSearchResultItem,
    SetOperation,
    SetOperationRequest,
)
from veupath_chatbot.transport.http.schemas.steps import RecordDetailRequest

router = APIRouter(prefix="/api/v1/gene-sets", tags=["gene-sets"])
logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _svc() -> GeneSetService:
    """Build a ``GeneSetService`` on top of the store from ``get_gene_set_store()``."""
    store = get_gene_set_store()
    return GeneSetService(store)


def _to_response(gs: GeneSet) -> GeneSetResponse:
    """Map a ``GeneSet`` onto the ``GeneSetResponse`` transport schema.

    The stored ``operation`` is forwarded only when it is one of the values
    returned by ``get_args(SetOperation)``; any other value is reported as
    ``None``. ``geneCount`` is the length of ``gene_ids`` and ``createdAt`` is
    the ``isoformat()`` rendering of ``created_at``.
    """
    allowed_operations = get_args(SetOperation)
    operation: SetOperation | None = None
    if gs.operation in allowed_operations:
        # cast() only informs the type checker; the membership test above is
        # the actual runtime validation.
        operation = cast(SetOperation, gs.operation)

    gene_ids = gs.gene_ids
    return GeneSetResponse(
        id=gs.id,
        siteId=gs.site_id,
        name=gs.name,
        geneIds=gene_ids,
        source=gs.source,
        geneCount=len(gene_ids),
        wdkStrategyId=gs.wdk_strategy_id,
        wdkStepId=gs.wdk_step_id,
        searchName=gs.search_name,
        recordType=gs.record_type,
        parameters=gs.parameters,
        parentSetIds=gs.parent_set_ids,
        operation=operation,
        createdAt=gs.created_at.isoformat(),
        stepCount=gs.step_count,
    )


def _not_found(exc: KeyError) -> NotFoundError:
    """Translate a lookup ``KeyError`` into the platform ``NotFoundError``.

    ``str(KeyError("msg"))`` yields the *repr* of the key (``"'msg'"``), which
    would put stray quotes into the error title. When the exception carries a
    single argument, that argument is used as the title instead; otherwise
    the plain ``str(exc)`` rendering is kept.

    Args:
        exc: The ``KeyError`` raised by the service layer.

    Returns:
        A ``NotFoundError`` whose ``title`` is the unquoted message.
    """
    if len(exc.args) == 1:
        title = str(exc.args[0])
    else:
        title = str(exc)
    return NotFoundError(title=title)


def _no_strategy(exc: ValueError) -> NotFoundError:
    """Wrap a ``ValueError`` as a ``NotFoundError`` titled "No WDK strategy".

    The exception's string form is carried over as the error ``detail``.
    """
    detail = str(exc)
    return NotFoundError(title="No WDK strategy", detail=detail)


# ---------------------------------------------------------------------------
# CRUD endpoints
# ---------------------------------------------------------------------------


@router.post("", status_code=201)
@limiter.limit("30/minute")
async def create_gene_set(
    request: Request,
    body: CreateGeneSetRequest,
    user_id: CurrentUser,
) -> GeneSetResponse:
    """Create a gene set for the current user from the request body.

    ``request`` is not read in the handler body; it is presumably required by
    the ``limiter.limit`` decorator -- confirm before removing it.
    """
    service = _svc()
    created = await service.create(
        user_id=user_id,
        name=body.name,
        site_id=body.site_id,
        gene_ids=body.gene_ids,
        source=body.source,
        wdk_strategy_id=body.wdk_strategy_id,
        wdk_step_id=body.wdk_step_id,
        search_name=body.search_name,
        record_type=body.record_type,
        parameters=body.parameters,
    )
    return _to_response(created)
@router.get("")
async def list_gene_sets(
    user_id: CurrentUser,
    site_id: str | None = Query(None, alias="siteId"),
) -> list[GeneSetResponse]:
    """Return the current user's gene sets, optionally narrowed by ``siteId``."""
    gene_sets = await _svc().list_for_user(user_id, site_id=site_id)
    responses = [_to_response(gene_set) for gene_set in gene_sets]
    return responses
@router.get("/{gene_set_id}")
async def get_gene_set(
    gene_set_id: str,
    user_id: CurrentUser,
) -> GeneSetResponse:
    """Fetch one of the current user's gene sets by its ID.

    A ``KeyError`` from the service is re-raised as ``NotFoundError``.
    """
    service = _svc()
    try:
        gene_set = await service.get_for_user(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    return _to_response(gene_set)
@router.delete("/{gene_set_id}")
async def delete_gene_set(
    gene_set_id: str,
    user_id: CurrentUser,
) -> dict[str, bool]:
    """Delete one of the current user's gene sets.

    Returns ``{"ok": True}`` on success; a ``KeyError`` from the service is
    re-raised as ``NotFoundError``.
    """
    service = _svc()
    try:
        await service.delete(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    acknowledgement: dict[str, bool] = {"ok": True}
    return acknowledgement
@router.post("/operations")
async def set_operations(
    request: SetOperationRequest,
    user_id: CurrentUser,
) -> GeneSetResponse:
    """Combine two gene sets (intersect, union, minus) into a new gene set.

    A ``KeyError`` from the service is re-raised as ``NotFoundError``; a
    ``ValueError`` is re-raised as ``ValidationError`` titled
    "Invalid operation".
    """
    service = _svc()
    try:
        derived = await service.perform_set_operation(
            user_id=user_id,
            set_a_id=request.set_a_id,
            set_b_id=request.set_b_id,
            operation=request.operation,
            name=request.name,
        )
    except KeyError as exc:
        raise _not_found(exc) from exc
    except ValueError as exc:
        raise ValidationError(title="Invalid operation", detail=str(exc)) from exc
    return _to_response(derived)
# ---------------------------------------------------------------------------
# Reverse search
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# Ensemble scoring
# ---------------------------------------------------------------------------
@router.post("/ensemble")
async def ensemble_scoring(
    body: EnsembleScoringRequest,
    user_id: CurrentUser,
) -> list[EnsembleScore]:
    """Score genes across the gene sets named in the request.

    Each ID in ``body.gene_set_ids`` is looked up for the current user, in
    order; the first ``KeyError`` is re-raised as ``NotFoundError``. The
    collected gene-ID lists and ``body.positive_controls`` are then passed to
    ``compute_ensemble_scores``.
    """
    service = _svc()
    try:
        # Lookups are awaited one at a time, in request order.
        members = [
            await service.get_for_user(user_id, member_id)
            for member_id in body.gene_set_ids
        ]
    except KeyError as exc:
        raise _not_found(exc) from exc

    gene_lists: list[list[str]] = [member.gene_ids for member in members]
    return compute_ensemble_scores(gene_lists, body.positive_controls)
# ---------------------------------------------------------------------------
# Enrichment
# ---------------------------------------------------------------------------
@router.post("/{gene_set_id}/enrich")
async def enrich_gene_set(
    gene_set_id: str,
    request: GeneSetEnrichRequest,
    user_id: CurrentUser,
) -> list[JSONObject]:
    """Run the requested enrichment analyses on a gene set.

    A ``KeyError`` from the service is re-raised as ``NotFoundError``; a
    ``RuntimeError`` is re-raised as ``InternalError`` titled
    "Enrichment analysis failed". Each result is serialised with ``to_json``.
    """
    service = _svc()
    try:
        enrichment_results = await service.run_enrichment(
            user_id, gene_set_id, request.enrichment_types
        )
    except KeyError as exc:
        raise _not_found(exc) from exc
    except RuntimeError as exc:
        raise InternalError(
            title="Enrichment analysis failed", detail=str(exc)
        ) from exc

    serialised = [to_json(result) for result in enrichment_results]
    return serialised
# ---------------------------------------------------------------------------
# Result browsing endpoints (attributes, records, distributions)
# ---------------------------------------------------------------------------
@router.get("/{gene_set_id}/results/attributes")
async def get_gene_set_attributes(
    gene_set_id: str,
    user_id: CurrentUser,
) -> JSONObject:
    """Return the attributes available for a gene set's record type.

    A ``KeyError`` while resolving the step-results service is re-raised as
    ``NotFoundError``; a ``ValueError`` is re-raised as the
    "No WDK strategy" ``NotFoundError``.
    """
    service = _svc()
    try:
        results = await service.get_step_results_service(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    except ValueError as exc:
        raise _no_strategy(exc) from exc
    return await results.get_attributes()
@router.get("/{gene_set_id}/results/records")
async def get_gene_set_records(
    gene_set_id: str,
    user_id: CurrentUser,
    offset: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=500),
    sort: str | None = None,
    dir: Literal["ASC", "DESC"] = "ASC",
    attributes: str | None = None,
    filter_attribute: str | None = Query(None, alias="filterAttribute"),
    filter_value: str | None = Query(None, alias="filterValue"),
) -> JSONObject:
    """Return one page of result records for a gene set.

    ``attributes`` is a comma-separated list; blank entries are dropped.

    Without a filter, paging is delegated to the step-results service. When
    both ``filterAttribute`` and ``filterValue`` are supplied, records are
    fetched from offset 0 with a limit of 10,000, filtered here by exact
    equality on ``record["attributes"][filterAttribute]``, and then paged
    locally. The ``meta`` block is built from the filtered list.

    A ``KeyError`` while resolving the step-results service is re-raised as
    ``NotFoundError``; a ``ValueError`` is re-raised as the
    "No WDK strategy" ``NotFoundError``.
    """
    service = _svc()
    try:
        results = await service.get_step_results_service(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    except ValueError as exc:
        raise _no_strategy(exc) from exc

    attr_list: list[str] | None = (
        [name.strip() for name in attributes.split(",") if name.strip()]
        if attributes
        else None
    )

    # Unfiltered request: let the results service handle paging directly.
    if not filter_attribute or filter_value is None:
        return await results.get_records(
            offset=offset,
            limit=limit,
            sort=sort,
            direction=dir,
            attributes=attr_list,
        )

    # Filtered request: fetch from the start of the step and filter here.
    # NOTE(review): only the first 10,000 records are fetched, so matches
    # beyond that would not be counted or returned -- confirm this is acceptable.
    answer = await results.get_records(
        offset=0,
        limit=10_000,
        sort=sort,
        direction=dir,
        attributes=attr_list,
    )
    records = answer.get("records", [])
    if not isinstance(records, list):
        records = []

    # Keep only dict records whose "attributes" dict matches the filter.
    filtered: list[JSONValue] = [
        record
        for record in records
        if isinstance(record, dict)
        and isinstance(attrs := record.get("attributes"), dict)
        and attrs.get(filter_attribute) == filter_value
    ]

    total = len(filtered)
    page = filtered[offset : offset + limit]
    return {
        "records": cast(JSONValue, page),
        "meta": {
            "totalCount": total,
            "displayTotalCount": total,
            "responseCount": len(page),
            "pagination": {"offset": offset, "numRecords": limit},
            "attributes": cast(JSONValue, attr_list or []),
            "tables": cast(JSONValue, []),
        },
    }
@router.get("/{gene_set_id}/results/distributions/{attribute_name}")
async def get_gene_set_distribution(
    gene_set_id: str,
    attribute_name: str,
    user_id: CurrentUser,
) -> JSONObject:
    """Return distribution data for one attribute of a gene set's results.

    Uses the byValue column reporter. A ``KeyError`` while resolving the
    step-results service is re-raised as ``NotFoundError``; a ``ValueError``
    is re-raised as the "No WDK strategy" ``NotFoundError``.
    """
    service = _svc()
    try:
        results = await service.get_step_results_service(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    except ValueError as exc:
        raise _no_strategy(exc) from exc
    return await results.get_distribution(attribute_name)
@router.post("/{gene_set_id}/results/record")
async def get_gene_set_record_detail(
    gene_set_id: str,
    body: RecordDetailRequest,
    user_id: CurrentUser,
) -> JSONObject:
    """Return the full details of a single record, addressed by primary key.

    The gene set is loaded first so its ``site_id`` can be passed along with
    the primary-key parts. A ``KeyError`` from either lookup is re-raised as
    ``NotFoundError``; a ``ValueError`` is re-raised as the
    "No WDK strategy" ``NotFoundError``.
    """
    service = _svc()
    try:
        gene_set = await service.get_for_user(user_id, gene_set_id)
        results = await service.get_step_results_service(user_id, gene_set_id)
    except KeyError as exc:
        raise _not_found(exc) from exc
    except ValueError as exc:
        raise _no_strategy(exc) from exc

    # Flatten the request's primary-key parts into plain name/value objects.
    primary_key: list[JSONObject] = []
    for part in body.primary_key:
        primary_key.append({"name": part.name, "value": part.value})
    return await results.get_record_detail(primary_key, gene_set.site_id)
# ---------------------------------------------------------------------------
# Confidence scoring
# ---------------------------------------------------------------------------
@router.post("/confidence")
async def gene_confidence(
    body: GeneConfidenceRequest,
) -> list[GeneConfidenceScoreResponse]:
    """Compute per-gene composite confidence scores from classification data.

    The request fields are forwarded unchanged to ``compute_gene_confidence``
    and each resulting score is mapped onto the camelCase response schema.
    """
    # NOTE(review): unlike the other handlers in this module, this one takes
    # no ``CurrentUser`` parameter -- confirm that is intentional.
    computed = compute_gene_confidence(
        tp_ids=body.tp_ids,
        fp_ids=body.fp_ids,
        fn_ids=body.fn_ids,
        tn_ids=body.tn_ids,
        ensemble_scores=body.ensemble_scores,
        enrichment_gene_counts=body.enrichment_gene_counts,
        max_enrichment_terms=body.max_enrichment_terms,
    )

    responses: list[GeneConfidenceScoreResponse] = []
    for score in computed:
        responses.append(
            GeneConfidenceScoreResponse(
                geneId=score.gene_id,
                compositeScore=score.composite_score,
                classificationScore=score.classification_score,
                ensembleScore=score.ensemble_score,
                enrichmentScore=score.enrichment_score,
            )
        )
    return responses