Source code for veupath_chatbot.services.strategies.wdk_counts

"""Step count computation: get per-step result counts from WDK.

Supports two paths:
- **Leaf-only strategies**: parallel anonymous reports (fast, no strategy creation)
- **Complex strategies**: temporary WDK strategy compilation (slower)

Results are cached by plan hash to avoid redundant API calls.
"""

import asyncio
import hashlib
import json
from collections import OrderedDict

from veupath_chatbot.domain.strategy.ast import StrategyAST
from veupath_chatbot.domain.strategy.compile import compile_strategy
from veupath_chatbot.integrations.veupathdb.client import VEuPathDBClient
from veupath_chatbot.integrations.veupathdb.factory import get_strategy_api
from veupath_chatbot.integrations.veupathdb.strategy_api import StrategyAPI
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.control_helpers import delete_temp_strategy
from veupath_chatbot.services.experiment.helpers import extract_wdk_id

logger = get_logger(__name__)

_STEP_COUNTS_CACHE: OrderedDict[str, dict[str, int | None]] = OrderedDict()
_STEP_COUNTS_CACHE_MAX = 20


[docs] def plan_cache_key(site_id: str, plan: JSONObject) -> str: payload = json.dumps(plan, sort_keys=True, separators=(",", ":")) digest = hashlib.sha256(payload.encode()).hexdigest() return f"{site_id}:{digest}"
async def _count_via_anonymous_report( client: VEuPathDBClient, record_type: str, search_name: str, parameters: JSONObject, ) -> int | None: """Get result count for a single search using the anonymous report endpoint. ``POST /record-types/{recordType}/searches/{searchName}/reports/standard`` with ``numRecords: 0`` returns only ``meta.totalCount`` — no step or strategy creation needed. Returns ``None`` on failure. """ search_config: JSONObject = {"parameters": parameters} report_config: JSONObject = {"pagination": {"offset": 0, "numRecords": 0}} try: result = await client.run_search_report( record_type, search_name, search_config, report_config ) meta = result.get("meta") if isinstance(meta, dict): count = meta.get("totalCount") if isinstance(count, int): return count except Exception as e: logger.warning( "Anonymous report count failed", record_type=record_type, search_name=search_name, error=str(e), ) return None
[docs] def is_leaf_only_strategy(strategy_ast: StrategyAST) -> bool: """Check if all steps in the strategy are leaf (search) steps.""" return all(step.infer_kind() == "search" for step in strategy_ast.get_all_steps())
[docs] async def compute_step_counts_for_plan( plan: JSONObject, strategy_ast: StrategyAST, site_id: str, ) -> dict[str, int | None]: """Compute per-step result counts for a strategy plan. For **leaf-only strategies** (all search steps, no combines/transforms), uses WDK's anonymous report endpoint in parallel — no step or strategy creation needed. This is dramatically faster than full compilation. For **complex strategies** (with combines/transforms), falls back to creating a temporary WDK strategy to get server-computed counts. Results are cached by plan hash. """ cache_key = plan_cache_key(site_id, plan) cached = _STEP_COUNTS_CACHE.get(cache_key) if cached is not None: _STEP_COUNTS_CACHE.move_to_end(cache_key) return cached api = get_strategy_api(site_id) # Fast path: leaf-only strategies use parallel anonymous reports. if is_leaf_only_strategy(strategy_ast): counts = await _compute_leaf_counts_parallel(api.client, strategy_ast) _cache_counts(cache_key, counts) return counts # Slow path: complex strategies require full WDK compilation. counts = await _compute_counts_via_compilation(api, strategy_ast, site_id) _cache_counts(cache_key, counts) return counts
async def _compute_leaf_counts_parallel( client: VEuPathDBClient, strategy_ast: StrategyAST, ) -> dict[str, int | None]: """Compute counts for all leaf steps in parallel using anonymous reports.""" all_steps = strategy_ast.get_all_steps() record_type = strategy_ast.record_type tasks = [ _count_via_anonymous_report( client, record_type, step.search_name, step.parameters or {} ) for step in all_steps ] results = await asyncio.gather(*tasks) return {step.id: count for step, count in zip(all_steps, results, strict=True)} async def _compute_counts_via_compilation( api: StrategyAPI, strategy_ast: StrategyAST, site_id: str, ) -> dict[str, int | None]: """Compute counts by creating a temporary WDK strategy (legacy path).""" result = await compile_strategy( strategy_ast, api, site_id=site_id, resolve_record_type=True, ) temp_strategy_id: int | None = None try: created = await api.create_strategy( step_tree=result.step_tree, name="Pathfinder step counts", description=None, is_internal=True, ) temp_strategy_id = extract_wdk_id(created) except Exception as exc: logger.error( "Failed to create temporary WDK strategy for step counts", error=str(exc), site_id=site_id, step_count=len(result.steps), ) counts: dict[str, int | None] = {step.local_id: None for step in result.steps} if temp_strategy_id is not None: try: wdk_strategy = await api.get_strategy(temp_strategy_id) if isinstance(wdk_strategy, dict): steps_dict = wdk_strategy.get("steps") if isinstance(steps_dict, dict): for step in result.steps: step_info = steps_dict.get(str(step.wdk_step_id)) if isinstance(step_info, dict): estimated = step_info.get("estimatedSize") if isinstance(estimated, int): counts[step.local_id] = estimated except Exception as e: logger.warning("Failed to read counts from strategy payload", error=str(e)) await delete_temp_strategy(api, temp_strategy_id) return counts def _cache_counts(cache_key: str, counts: dict[str, int | None]) -> None: """Store counts in the LRU cache.""" _STEP_COUNTS_CACHE[cache_key] = counts if len(_STEP_COUNTS_CACHE) > _STEP_COUNTS_CACHE_MAX: _STEP_COUNTS_CACHE.popitem(last=False)