Source code for veupath_chatbot.services.experiment.ai_analysis_helpers

"""Helper functions for experiment analysis AI tools.

Utility functions for extracting WDK record data, classifying genes,
searching records, and fetching result IDs.
"""

from typing import cast

from veupath_chatbot.integrations.veupathdb.strategy_api import StrategyAPI
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.wdk.helpers import extract_pk

logger = get_logger(__name__)


[docs] def classify_gene( gene_id: str | None, tp_ids: set[str], fp_ids: set[str], fn_ids: set[str], tn_ids: set[str], ) -> str | None: """Return the classification label for a gene ID. :param gene_id: Gene identifier to classify. :param tp_ids: True positive gene IDs. :param fp_ids: False positive gene IDs. :param fn_ids: False negative gene IDs. :param tn_ids: True negative gene IDs. :returns: One of ``"TP"``, ``"FP"``, ``"FN"``, ``"TN"``, or None. """ if not gene_id: return None if gene_id in tp_ids: return "TP" if gene_id in fp_ids: return "FP" if gene_id in fn_ids: return "FN" if gene_id in tn_ids: return "TN" return None
[docs] def record_matches(attrs: JSONObject, query_lower: str, attribute: str | None) -> bool: """Check if a record's attributes match a text query. :param attrs: Record attribute dict. :param query_lower: Lowercased search query. :param attribute: Specific attribute to search in, or None for all. :returns: True if any matching attribute value is found. """ if attribute: val = attrs.get(attribute) return isinstance(val, str) and query_lower in val.lower() return any(isinstance(v, str) and query_lower in v.lower() for v in attrs.values())
[docs] async def build_primary_key( api: StrategyAPI, site_id: str, record_type: str, gene_id: str ) -> list[JSONObject]: """Build a complete WDK primary key for a gene ID. WDK requires all primary key columns (e.g. ``source_id`` + ``project_id`` for gene records). This helper fetches the record type info and fills missing columns from site configuration. :param api: Strategy API instance. :param site_id: VEuPathDB site identifier. :param record_type: WDK record type. :param gene_id: Gene identifier (the ``source_id`` value). :returns: List of ``{name, value}`` dicts forming the complete PK. """ from veupath_chatbot.integrations.veupathdb.factory import get_site pk_parts: list[JSONObject] = [{"name": "source_id", "value": gene_id}] try: info = await api.get_record_type_info(record_type) pk_refs = info.get("primaryKeyColumnRefs") or info.get("primaryKey") or [] if isinstance(pk_refs, list) and len(pk_parts) < len(pk_refs): names_sent = {"source_id"} site = get_site(site_id) pk_defaults: dict[str, str] = {"project_id": site.project_id} for col in pk_refs: if not isinstance(col, str) or col in names_sent: continue default_val = pk_defaults.get(col) if default_val: pk_parts.append({"name": col, "value": default_val}) except Exception as exc: logger.debug( "Failed to build full primary key, falling back to source_id only", gene_id=gene_id, error=str(exc), ) return pk_parts
[docs] async def fetch_group_records( api: StrategyAPI, record_type: str, gene_ids: list[str], limit: int = 20, site_id: str | None = None, ) -> list[JSONObject]: """Fetch records for a list of gene IDs. :param api: Strategy API instance. :param record_type: WDK record type. :param gene_ids: Gene IDs to fetch. :param limit: Max number of genes to fetch. :param site_id: Site ID for PK completion (fills project_id etc.). :returns: List of dicts with ``geneId`` and ``attributes``. """ results: list[JSONObject] = [] for gene_id in gene_ids[:limit]: try: if site_id: pk = await build_primary_key(api, site_id, record_type, gene_id) else: pk = cast( list[JSONObject], [{"name": "source_id", "value": gene_id}], ) rec = await api.get_single_record( record_type=record_type, primary_key=pk, ) if isinstance(rec, dict): results.append( { "geneId": gene_id, "attributes": rec.get("attributes", {}), } ) except Exception: continue return results
[docs] async def collect_all_result_ids(api: StrategyAPI, step_id: int) -> set[str]: """Fetch all result gene IDs from a WDK step by paginating. :param api: Strategy API instance. :param step_id: WDK step ID. :returns: Set of all gene IDs in the step's results. """ result_ids: set[str] = set() offset = 0 page_size = 1000 while True: answer = await api.get_step_records( step_id=step_id, attributes=[], pagination={"offset": offset, "numRecords": page_size}, ) records = answer.get("records", []) if not isinstance(records, list) or not records: break for rec in records: if not isinstance(rec, dict): continue gene_id = extract_pk(rec) if gene_id: result_ids.add(gene_id) offset += len(records) if len(records) < page_size: break return result_ids