Source code for veupath_chatbot.ai.tools.export_tools

"""AI tools for exporting data as downloadable files."""

from typing import Annotated, Literal, cast
from uuid import UUID

from kani import AIParam, ai_function

from veupath_chatbot.platform.errors import ErrorCode
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import JSONObject, JSONValue
from veupath_chatbot.services.export import get_export_service
from veupath_chatbot.services.gene_sets.store import get_gene_set_store


[docs] class ExportToolsMixin: """Kani tool mixin for exporting data as downloadable files.""" site_id: str = "" user_id: UUID | None = None def _available_gene_sets(self) -> list[JSONObject]: """Return summary of available gene sets for error messages.""" store = get_gene_set_store() if self.user_id is not None: sets = store.list_for_user(self.user_id, site_id=self.site_id) else: sets = store.list_all(site_id=self.site_id) return [ {"id": gs.id, "name": gs.name, "geneCount": len(gs.gene_ids)} for gs in sets[:10] ]
[docs] @ai_function() async def export_gene_set( self, gene_set_id: Annotated[str, AIParam(desc="PathFinder gene set ID")], format: Annotated[ str, AIParam(desc="Export format: csv or txt"), ] = "csv", ) -> JSONObject: """Export a gene set as a downloadable CSV or TXT file. Returns a download URL that the user can click to download the file. The URL expires after 10 minutes. """ if format not in ("csv", "txt"): return tool_error( ErrorCode.VALIDATION_ERROR, "format must be 'csv' or 'txt'.", format=format, ) store = get_gene_set_store() gs = await store.aget(gene_set_id) if gs is None: available = self._available_gene_sets() return tool_error( ErrorCode.NOT_FOUND, f"Gene set not found: {gene_set_id}. Use one of the available IDs below.", gene_set_id=gene_set_id, availableGeneSets=cast(JSONValue, available), ) svc = get_export_service() fmt: Literal["csv", "txt"] = "txt" if format == "txt" else "csv" result = await svc.export_gene_set(gs, fmt) return { "downloadUrl": result.url, "filename": result.filename, "format": format, "itemCount": len(gs.gene_ids), "expiresInSeconds": result.expires_in_seconds, }