Source code for veupath_chatbot.services.wdk.enrichment_service

"""Unified enrichment service.

Single entry point for running enrichment analyses regardless of
whether the caller is an experiment endpoint, gene set endpoint,
or AI tool.

Rate limiting
-------------
A process-level semaphore (``_WDK_ENRICHMENT_SEMAPHORE``) limits how
many ``run_batch`` calls can execute concurrently across the entire
application.  Within a single batch, analyses run in parallel via
``asyncio.gather`` to keep total wall-clock time within proxy timeouts.
"""

import asyncio

from veupath_chatbot.domain.strategy.ast import StepTreeNode
from veupath_chatbot.integrations.veupathdb.factory import get_strategy_api
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.control_helpers import delete_temp_strategy
from veupath_chatbot.services.experiment.enrichment import (
    _execute_analysis,
    run_enrichment_analysis,
    run_enrichment_on_step,
)
from veupath_chatbot.services.experiment.helpers import coerce_step_id, extract_wdk_id
from veupath_chatbot.services.experiment.types import (
    EnrichmentAnalysisType,
    EnrichmentResult,
)

# Module-level logger; call sites in this file pass context as keyword
# arguments (e.g. ``analysis_type=...``, ``error=...``).
logger = get_logger(__name__)

# Limit concurrent enrichment batches process-wide.
# WDK's step analysis API becomes unreliable under parallel load.
# This limits how many run_batch calls execute simultaneously, not
# individual analyses within a batch: at most 3 batches hold the
# semaphore at once and further run_batch calls wait for a free slot.
# EnrichmentService.run (single analysis) does not acquire it.
_WDK_ENRICHMENT_SEMAPHORE = asyncio.Semaphore(3)


class EnrichmentService:
    """Unified enrichment dispatcher.

    Offers a single-analysis entry point (``run``) and a batched entry
    point (``run_batch``).  Both accept either an existing WDK step id or
    a ``search_name`` + ``parameters`` pair from which a temporary
    step/strategy is built.
    """

    async def run(
        self,
        *,
        site_id: str,
        analysis_type: EnrichmentAnalysisType,
        step_id: int | None = None,
        search_name: str | None = None,
        record_type: str | None = None,
        parameters: JSONObject | None = None,
    ) -> EnrichmentResult:
        """Run a single enrichment analysis.

        If step_id is provided, runs on the existing step.  Otherwise
        creates a temporary strategy from search_name/parameters;
        ``record_type`` falls back to ``"transcript"`` when omitted.

        Raises ``ValueError`` when neither a ``step_id`` nor a non-empty
        ``search_name`` together with ``parameters`` is supplied.
        """
        if step_id is not None:
            return await run_enrichment_on_step(
                site_id=site_id,
                step_id=step_id,
                analysis_type=analysis_type,
            )

        if search_name and parameters is not None:
            return await run_enrichment_analysis(
                site_id=site_id,
                record_type=record_type or "transcript",
                search_name=search_name,
                parameters=parameters,
                analysis_type=analysis_type,
            )

        raise ValueError("Either step_id or search_name+parameters required")

    async def run_batch(
        self,
        *,
        site_id: str,
        analysis_types: list[EnrichmentAnalysisType],
        step_id: int | None = None,
        search_name: str | None = None,
        record_type: str | None = None,
        parameters: JSONObject | None = None,
    ) -> tuple[list[EnrichmentResult], list[str]]:
        """Run multiple enrichment analyses concurrently on a shared step.

        When no step_id is provided (paste gene sets), creates ONE temporary
        WDK step/strategy and runs all analysis types against it — instead
        of creating N separate temp strategies.  This reduces WDK API calls
        from ~5N to ~N+3 and avoids rate-limit 500s.

        Returns a ``(results, errors)`` pair: one result per requested
        analysis type, in request order, plus one ``"<type>: <message>"``
        string per analysis that raised.

        When ``analysis_types`` is empty and no ``step_id`` is given, the
        call returns ``([], [])`` without creating a temporary strategy.

        Raises ``ValueError`` when neither a ``step_id`` nor a non-empty
        ``search_name`` together with ``parameters`` is supplied.
        """
        errors: list[str] = []

        # If we already have a step, run all analyses on it directly.
        if step_id is not None:
            async with _WDK_ENRICHMENT_SEMAPHORE:
                return await self._run_analyses_on_step(
                    site_id,
                    step_id,
                    analysis_types,
                    errors,
                )

        # No step — need search_name + parameters to create one.
        if not search_name or parameters is None:
            raise ValueError("Either step_id or search_name+parameters required")

        # Nothing to analyse: skip creating (and then deleting) a temp strategy.
        if not analysis_types:
            return [], errors

        # Create ONE temp step/strategy, run all analyses, then clean up.
        api = get_strategy_api(site_id)
        step = await api.create_step(
            record_type=record_type or "transcript",
            search_name=search_name,
            parameters=parameters,
            custom_name="Enrichment target",
        )
        shared_step_id = coerce_step_id(step)
        root = StepTreeNode(shared_step_id)
        strategy_id: int | None = None

        async with _WDK_ENRICHMENT_SEMAPHORE:
            try:
                created = await api.create_strategy(
                    step_tree=root,
                    name="Pathfinder enrichment analysis",
                    description=None,
                    is_internal=True,
                )
                strategy_id = extract_wdk_id(created)
                return await self._run_analyses_on_step(
                    site_id,
                    shared_step_id,
                    analysis_types,
                    errors,
                )
            finally:
                # Runs on success and on failure; strategy_id is still None
                # here if create_strategy itself raised.
                await delete_temp_strategy(api, strategy_id)

    async def _run_analyses_on_step(
        self,
        site_id: str,
        step_id: int,
        analysis_types: list[EnrichmentAnalysisType],
        errors: list[str],
    ) -> tuple[list[EnrichmentResult], list[str]]:
        """Run multiple analysis types on a single step concurrently.

        Analyses run in parallel to keep total wall-clock time under
        proxy timeouts (~30s instead of ~90s sequential).  A process-level
        semaphore still limits how many ``run_batch`` calls execute
        concurrently across different requests.

        An analysis that raises does not abort the batch: it is logged and
        replaced by an empty ``EnrichmentResult`` carrying the error text.
        Failure messages are appended to the caller-supplied ``errors``
        list in the order of ``analysis_types`` (not completion order), and
        that same list is returned alongside the results.
        """
        api = get_strategy_api(site_id)

        async def _run_one(
            analysis_type: EnrichmentAnalysisType,
        ) -> tuple[EnrichmentResult, str | None]:
            # Returns the result plus a failure message (None on success).
            try:
                return await _execute_analysis(api, step_id, analysis_type), None
            except Exception as exc:
                # Deliberately broad: one failing analysis must not sink
                # the others running alongside it.
                logger.warning(
                    "Enrichment failed",
                    analysis_type=analysis_type,
                    error=str(exc),
                )
                error_msg = str(exc)
                placeholder = EnrichmentResult(
                    analysis_type=analysis_type,
                    terms=[],
                    total_genes_analyzed=0,
                    background_size=0,
                    error=error_msg,
                )
                return placeholder, f"{analysis_type}: {error_msg}"

        # gather() yields outcomes in the order the awaitables were passed,
        # so results and errors both follow analysis_types.
        outcomes = await asyncio.gather(*(_run_one(t) for t in analysis_types))
        results = [result for result, _ in outcomes]
        errors.extend(message for _, message in outcomes if message is not None)
        return results, errors