Source code for veupath_chatbot.services.gene_lookup.wdk

"""WDK-based gene search and ID resolution."""

from dataclasses import dataclass
from typing import cast

from veupath_chatbot.integrations.veupathdb.factory import get_wdk_client
from veupath_chatbot.integrations.veupathdb.site_search import strip_html_tags
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject

from .organism import normalize_organism
from .result import DEFAULT_GENE_ATTRIBUTES, build_gene_result

logger = get_logger(__name__)

WDK_WILDCARD_LIMIT = 50

WDK_TEXT_FIELDS_ID: list[str] = ["primary_key", "Alias"]
WDK_TEXT_FIELDS_BROAD: list[str] = [
    "product",
    "name",
    "primary_key",
    "Alias",
    "GOTerms",
    "Notes",
    "Products",
]


[docs] @dataclass class WdkTextResult: """Results from a WDK ``GenesByText`` query.""" records: list[JSONObject] total_count: int
def _parse_wdk_record(rec: JSONObject) -> JSONObject | None: """Parse a WDK record into a standard gene result dict.""" rec_attrs = rec.get("attributes") if not isinstance(rec_attrs, dict): rec_attrs = {} pk = rec.get("id") gene_id = "" if isinstance(pk, list): for elem in pk: if not isinstance(elem, dict): continue name = elem.get("name") val = elem.get("value") if ( name in ("gene_source_id", "source_id", "gene") and isinstance(val, str) and val.strip() ): gene_id = val.strip() break if not gene_id: gene_id = str(rec_attrs.get("primary_key", "")).strip() gene_name_raw = rec_attrs.get("gene_name", "") gene_name = strip_html_tags(str(gene_name_raw)) if gene_name_raw else "" product_raw = rec_attrs.get("gene_product", "") product = strip_html_tags(str(product_raw)) if product_raw else "" organism_raw = rec_attrs.get("organism", "") org = normalize_organism(str(organism_raw)) if organism_raw else "" gene_type = str(rec_attrs.get("gene_type", "")) location = str(rec_attrs.get("gene_location_text", "")) gene_source_id = str(rec_attrs.get("gene_source_id", "")) previous_ids = str(rec_attrs.get("gene_previous_ids", "")) return build_gene_result( gene_id=gene_source_id or gene_id, display_name=gene_name or product or gene_id, organism=org, product=product, gene_name=gene_name, gene_type=gene_type, location=location, previous_ids=previous_ids or "", )
[docs] async def fetch_wdk_text_genes( site_id: str, expressions: list[str], *, organism: str | None = None, text_fields: list[str] | None = None, record_type: str = "transcript", limit: int = WDK_WILDCARD_LIMIT, ) -> WdkTextResult: """Search genes via WDK ``GenesByText``.""" if not expressions or not organism: return WdkTextResult(records=[], total_count=0) import json fields = text_fields or WDK_TEXT_FIELDS_ID client = get_wdk_client(site_id) all_results: list[JSONObject] = [] wdk_total: int = 0 for pattern in expressions: answer = await client.post( f"/record-types/{record_type}/searches/GenesByText/reports/standard", json=cast( JSONObject, { "searchConfig": { "parameters": { "text_expression": pattern, "text_fields": json.dumps(fields), "text_search_organism": json.dumps([organism]), "document_type": "gene", }, }, "reportConfig": { "attributes": DEFAULT_GENE_ATTRIBUTES, "tables": [], "pagination": {"offset": 0, "numRecords": limit}, }, }, ), ) if not isinstance(answer, dict): continue meta = answer.get("meta") if isinstance(meta, dict): mt = meta.get("totalCount") if isinstance(mt, int): wdk_total = max(wdk_total, mt) raw_records = answer.get("records") if not isinstance(raw_records, list): continue for rec in raw_records: if not isinstance(rec, dict): continue parsed = _parse_wdk_record(rec) if parsed: all_results.append(parsed) if len(all_results) >= limit: break records = all_results[:limit] return WdkTextResult( records=records, total_count=max(wdk_total, len(records)), )
[docs] async def resolve_gene_ids( site_id: str, gene_ids: list[str], *, record_type: str = "transcript", search_name: str = "GeneByLocusTag", param_name: str = "ds_gene_ids", attributes: list[str] | None = None, ) -> JSONObject: """Resolve a list of gene IDs to full records via the WDK standard reporter. Uses a dedicated short-lived WDK client to guarantee session affinity between dataset creation and the subsequent search. The shared singleton client's cookie jar is modified by concurrent requests, which can cause the dataset to "not belong" to the search session (WDK tracks anonymous users via session cookies). """ if not gene_ids: return {"records": [], "totalCount": 0} from veupath_chatbot.integrations.veupathdb.client import VEuPathDBClient from veupath_chatbot.integrations.veupathdb.site_router import get_site_router from veupath_chatbot.platform.config import get_settings router = get_site_router() site = router.get_site(site_id) settings = get_settings() routing = router._config.routing timeout = float( routing.portal_timeout if site.is_portal else routing.component_timeout ) client = VEuPathDBClient( base_url=site.service_url, timeout=timeout, auth_token=settings.veupathdb_auth_token, ) attrs = attributes or DEFAULT_GENE_ATTRIBUTES try: dataset_resp = await client.post( "/users/current/datasets", json=cast( JSONObject, {"sourceType": "idList", "sourceContent": {"ids": gene_ids}}, ), ) if not isinstance(dataset_resp, dict): return { "records": [], "totalCount": 0, "error": "Failed to create dataset for ID lookup.", } dataset_id = dataset_resp.get("id") if dataset_id is None: return { "records": [], "totalCount": 0, "error": "Dataset creation returned no ID.", } answer = await client.post( f"/record-types/{record_type}/searches/{search_name}/reports/standard", json=cast( JSONObject, { "searchConfig": { "parameters": {param_name: str(dataset_id)}, }, "reportConfig": { "attributes": attrs, "tables": [], }, }, ), ) except Exception as exc: logger.warning( "Gene ID resolution via standard reporter failed", site_id=site_id, gene_ids_count=len(gene_ids), error=str(exc), ) return { "records": [], "totalCount": 0, "error": f"WDK lookup failed: {exc}", } finally: await client.close() if not isinstance(answer, dict): return {"records": [], "totalCount": 0} raw_records = answer.get("records") if not isinstance(raw_records, list): return {"records": [], "totalCount": 0} records: list[JSONObject] = [] for rec in raw_records: if not isinstance(rec, dict): continue parsed = _parse_wdk_record(rec) if parsed: records.append(parsed) meta = answer.get("meta") or {} total = ( meta.get("totalCount", len(records)) if isinstance(meta, dict) else len(records) ) return cast(JSONObject, {"records": records, "totalCount": total})