# Source code for veupath_chatbot.services.strategies.build

"""Strategy build service: compile, create/update, and count extraction.

Encapsulates all business logic for building a strategy on WDK. The AI tool
layer delegates to this module and only handles argument parsing and response
formatting.
"""

from dataclasses import dataclass
from typing import Protocol

from veupath_chatbot.domain.strategy.ast import PlanStepNode, StepTreeNode, StrategyAST
from veupath_chatbot.domain.strategy.compile import (
    CompilationResult,
    StepDecoratorAPI,
    StrategyCompilerAPI,
    apply_step_decorations,
    compile_strategy,
)
from veupath_chatbot.domain.strategy.session import StrategyGraph
from veupath_chatbot.domain.strategy.validate import validate_strategy
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.catalog.searches import (
    make_record_type_resolver,
    resolve_record_type_from_steps,
)
from veupath_chatbot.services.experiment.helpers import extract_wdk_id

logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Protocols — I/O boundaries the build service depends on
# ---------------------------------------------------------------------------


class StrategyBuildAPI(StrategyCompilerAPI, StepDecoratorAPI, Protocol):
    """Structural interface for every WDK call the build service makes.

    It adds strategy create/update/read and the per-step count lookup to the
    compiler and step-decorator protocols. The real ``StrategyAPI`` from the
    integrations layer satisfies it.
    """

    async def create_strategy(
        self,
        step_tree: StepTreeNode,
        name: str,
        description: str | None = None,
    ) -> JSONObject:
        """Create a WDK strategy from ``step_tree``; returns WDK's JSON reply."""
        ...

    async def update_strategy(
        self,
        strategy_id: int,
        step_tree: StepTreeNode | None = None,
        name: str | None = None,
    ) -> JSONObject:
        """Update the WDK strategy identified by ``strategy_id``."""
        ...

    async def get_strategy(self, strategy_id: int) -> JSONObject:
        """Fetch the WDK strategy payload for ``strategy_id``."""
        ...

    async def get_step_count(self, step_id: int) -> int:
        """Return the result count of the WDK step ``step_id``."""
        ...
class SiteInfoLike(Protocol):
    """Structural interface for the site metadata the build service reads."""

    def strategy_url(
        self,
        strategy_id: int,
        root_step_id: int | None = None,
    ) -> str:
        """Return a URL for WDK strategy ``strategy_id`` (optionally a root step)."""
        ...
# ---------------------------------------------------------------------------
# Factory helpers — resolve integrations so tool layer doesn't import them
# ---------------------------------------------------------------------------


def _get_build_api(site_id: str) -> StrategyBuildAPI:
    """Look up the strategy API for ``site_id`` through the integrations factory.

    The factory import happens at call time, so importing this module does
    not import the integrations package.
    """
    from veupath_chatbot.integrations.veupathdb.factory import (
        get_strategy_api as lookup_strategy_api,
    )

    return lookup_strategy_api(site_id)


def _get_site_info(site_id: str) -> SiteInfoLike:
    """Look up site metadata for ``site_id`` through the integrations factory.

    Like :func:`_get_build_api`, the factory import is deferred to call time.
    """
    from veupath_chatbot.integrations.veupathdb.factory import (
        get_site as lookup_site,
    )

    return lookup_site(site_id)


# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass()
class BuildResult:
    """Everything a caller receives after a successful strategy build."""

    # WDK strategy id; None when no id could be obtained from WDK.
    wdk_strategy_id: int | None
    # Link to the strategy on the site; None when there is no strategy id.
    wdk_url: str | None
    # WDK step id of the compiled root step.
    root_step_id: int
    # Result count of the root step, when WDK reported one.
    root_count: int | None
    # Number of steps produced by compilation.
    step_count: int
    # Result count per local step id (None where WDK gave no integer size).
    counts: dict[str, int | None]
    # Local step ids whose count is exactly zero.
    zero_step_ids: list[str]
    # Full compilation output the build was based on.
    compilation: CompilationResult
@dataclass()
class StepCountResult:
    """Result count reported for one WDK step."""

    # WDK step id the count refers to.
    step_id: int
    # Number of results for that step.
    count: int
# --------------------------------------------------------------------------- # Root resolution # ---------------------------------------------------------------------------
class RootResolutionError(Exception):
    """Signals that a single root step could not be chosen from the graph.

    ``root_count`` carries the number of subtree roots when that is why
    resolution failed; it defaults to 0.
    """

    def __init__(self, message: str, root_count: int = 0) -> None:
        self.root_count = root_count
        super().__init__(message)
def resolve_root_step(
    graph: StrategyGraph,
    explicit_root_step_id: str | None,
) -> PlanStepNode:
    """Resolve the step a build should start from.

    Resolution order:

    1. A non-empty ``explicit_root_step_id`` selects that step directly.
    2. Otherwise the graph must have exactly one subtree root, which is used.

    :param graph: Strategy graph.
    :param explicit_root_step_id: Optional explicit root step ID override.
    :returns: The resolved root PlanStepNode.
    :raises RootResolutionError: When the explicit step is missing, when the
        graph has more than one root (``root_count`` is set), when the single
        root ID does not resolve to a step, or when the graph has no roots.
    """
    if explicit_root_step_id:
        step = graph.get_step(explicit_root_step_id)
        if step:
            return step
        raise RootResolutionError(
            f"Explicit root step '{explicit_root_step_id}' not found in graph."
        )

    roots = graph.roots
    if len(roots) == 1:
        root_id = next(iter(roots))
        step = graph.get_step(root_id)
        if step:
            return step
        # The graph lists a root that has no matching step. The graph is not
        # empty, so falling through to "No steps in graph" would mislead.
        raise RootResolutionError(
            f"Root step '{root_id}' is listed as a graph root but was not found in graph."
        )
    if len(roots) > 1:
        raise RootResolutionError(
            f"Graph has {len(roots)} subtree roots -- expected exactly 1 to build. "
            "Combine them first, or specify root_step_id.",
            root_count=len(roots),
        )
    raise RootResolutionError("No steps in graph. Create steps before building.")
# --------------------------------------------------------------------------- # Strategy AST creation # ---------------------------------------------------------------------------
def create_strategy_ast(
    graph: StrategyGraph,
    root_step: object,
    strategy_name: str | None,
    description: str | None,
) -> StrategyAST:
    """Build a StrategyAST rooted at ``root_step`` and validate it.

    The record type is taken from ``graph.record_type``, which must already be
    set (by step creation or by auto-resolution in :func:`build_strategy`).
    The AST is named ``strategy_name`` when given, otherwise ``graph.name``.

    :param graph: Strategy graph supplying the record type and fallback name.
    :param root_step: Root step; must be a PlanStepNode.
    :param strategy_name: Optional strategy name.
    :param description: Optional description.
    :returns: Validated StrategyAST.
    :raises TypeError: When ``root_step`` is not a PlanStepNode.
    :raises ValueError: When the record type is not set or validation fails.
    """
    if not isinstance(root_step, PlanStepNode):
        raise TypeError(f"Expected PlanStepNode, got {type(root_step).__name__}")

    record_type = graph.record_type
    if not record_type:
        raise ValueError("Record type could not be inferred for execution.")

    candidate = StrategyAST(
        record_type=record_type,
        root=root_step,
        name=strategy_name or graph.name,
        description=description,
    )

    outcome = validate_strategy(candidate)
    if outcome.valid:
        return candidate

    problems = [{"path": err.path, "message": err.message} for err in outcome.errors]
    raise ValueError(f"Strategy validation failed: {problems}")
# --------------------------------------------------------------------------- # WDK create-or-update # ---------------------------------------------------------------------------
async def create_or_update_wdk_strategy(
    api: StrategyBuildAPI,
    compilation_result: CompilationResult,
    strategy: StrategyAST,
    existing_wdk_id: int | None,
) -> int | None:
    """Push the compiled strategy to WDK, updating in place when possible.

    With an ``existing_wdk_id``, an update is attempted first. Any exception
    from the update is logged and the function falls through to creating a
    new strategy. Without an existing id, a new strategy is created directly.

    :param api: WDK strategy API.
    :param compilation_result: Compiled step tree to send.
    :param strategy: Strategy AST supplying name and description.
    :param existing_wdk_id: Previously built WDK strategy id, if any.
    :returns: The WDK strategy ID, or None if none could be extracted from the
        creation response.
    """
    display_name = strategy.name or "Untitled Strategy"

    if existing_wdk_id is not None:
        try:
            await api.update_strategy(
                strategy_id=existing_wdk_id,
                step_tree=compilation_result.step_tree,
                name=display_name,
            )
            logger.info(
                "Updated existing WDK strategy",
                wdk_strategy_id=existing_wdk_id,
            )
            return existing_wdk_id
        except Exception as update_err:
            # Deliberate fallback (e.g. the strategy was deleted on WDK).
            logger.warning(
                "Failed to update WDK strategy, will create new",
                wdk_strategy_id=existing_wdk_id,
                error=str(update_err),
            )

    created = await api.create_strategy(
        step_tree=compilation_result.step_tree,
        name=display_name,
        description=strategy.description,
    )
    return extract_wdk_id(created)
# --------------------------------------------------------------------------- # Step counts extraction # ---------------------------------------------------------------------------
def extract_step_counts(
    strategy_info: JSONObject,
    compiled_map: dict[str, int],
) -> tuple[dict[str, int | None], int | None]:
    """Extract per-step result counts from a WDK strategy payload.

    Only steps whose WDK id appears in ``compiled_map`` are reported. Each
    reported count is the integer ``estimatedSize`` from the payload, or
    ``None`` when that field is missing or not an integer. Step entries whose
    key is not an integer string, or whose value is not a dict, are skipped.

    :param strategy_info: Raw WDK strategy dict (from ``api.get_strategy``).
    :param compiled_map: Mapping of local_step_id -> wdk_step_id.
    :returns: Tuple of (counts keyed by local step id, root_count).
        ``root_count`` is the count of the step named by ``rootStepId``, or
        ``None`` when it cannot be determined.
    """
    step_counts: dict[str, int | None] = {}
    steps_raw = strategy_info.get("steps")
    if not isinstance(steps_raw, dict):
        return step_counts, None

    wdk_to_local = {wdk_id: local_id for local_id, wdk_id in compiled_map.items()}

    for wdk_id_str, step_info in steps_raw.items():
        if not isinstance(step_info, dict):
            continue
        # Parenthesized tuple: the bare ``except A, B:`` form is only valid on
        # Python 3.14+ (PEP 758) and reads like Python 2's ``except A, name:``.
        try:
            wdk_id_int = int(wdk_id_str)
        except (ValueError, TypeError):
            continue
        local_id = wdk_to_local.get(wdk_id_int)
        # Compare against None: a falsy-but-valid local id must not be dropped.
        if local_id is None:
            continue
        estimated = step_info.get("estimatedSize")
        step_counts[local_id] = estimated if isinstance(estimated, int) else None

    root_count: int | None = None
    root_step_id_raw = strategy_info.get("rootStepId")
    if isinstance(root_step_id_raw, int):
        root_local = wdk_to_local.get(root_step_id_raw)
        if root_local is not None:
            root_count = step_counts.get(root_local)
    return step_counts, root_count
# --------------------------------------------------------------------------- # Full build orchestration # ---------------------------------------------------------------------------
async def _fetch_step_counts(
    api: StrategyBuildAPI,
    wdk_strategy_id: int | None,
    compiled_map: dict[str, int],
) -> tuple[dict[str, int | None], int | None]:
    """Best-effort count lookup for a built strategy; ``({}, None)`` on any failure."""
    if wdk_strategy_id is None:
        return {}, None
    try:
        strategy_info = await api.get_strategy(wdk_strategy_id)
        if isinstance(strategy_info, dict):
            return extract_step_counts(strategy_info, compiled_map)
    except Exception as e:
        # Counts are informational; a failed lookup must not fail the build.
        logger.warning("Strategy count lookup failed", error=str(e))
    return {}, None


async def build_strategy(
    *,
    graph: StrategyGraph,
    api: StrategyBuildAPI,
    site: SiteInfoLike,
    site_id: str,
    root_step_id: str | None = None,
    strategy_name: str | None = None,
    description: str | None = None,
) -> BuildResult:
    """Build or update a strategy on WDK.

    Orchestrates, in order: root resolution, record-type auto-resolution, AST
    (re)creation, WDK compilation, create-or-update, step decorations, graph
    state mutation, and best-effort count extraction.

    Record type is auto-resolved from leaf searches via the resolver from
    ``make_record_type_resolver`` — callers never need to supply it.

    On success the graph is mutated: ``record_type`` is filled in when unset;
    ``current_strategy`` is replaced and a history entry saved when the AST is
    (re)created; ``name`` is set when ``strategy_name`` is given; and
    ``wdk_step_ids``, ``wdk_strategy_id`` and ``step_counts`` are overwritten.

    :raises RootResolutionError: If root step cannot be determined.
    :raises ValueError: If validation or record type inference fails.
    :raises Exception: On WDK API failures.
    """
    strategy = graph.current_strategy
    # resolve_root_step returns a step or raises, so root_step is never None.
    root_step = resolve_root_step(graph, root_step_id)

    # Create resolver upfront for both auto-resolution and compilation.
    resolver = await make_record_type_resolver(site_id)

    # Auto-resolve record type from leaf searches if not set on graph.
    if not graph.record_type:
        graph.record_type = await resolve_record_type_from_steps(root_step, resolver)

    # (Re)create the AST when there is none, or when the resolved root step
    # is not part of the current AST.
    if not strategy or (
        root_step.id is not None and strategy.get_step_by_id(root_step.id) is None
    ):
        strategy = create_strategy_ast(graph, root_step, strategy_name, description)
        graph.current_strategy = strategy
        # Label with the AST's actual name (which may come from graph.name),
        # not just the optional argument.
        graph.save_history(f"Created strategy: {strategy.name or 'Untitled Strategy'}")

    if strategy_name:
        strategy.name = strategy_name
        graph.name = strategy_name

    logger.info("Building strategy", name=strategy.name)
    compilation_result = await compile_strategy(
        strategy, api, site_id=site_id, resolve_search_record_type=resolver
    )

    wdk_strategy_id = await create_or_update_wdk_strategy(
        api, compilation_result, strategy, graph.wdk_strategy_id
    )

    compiled_map = {s.local_id: s.wdk_step_id for s in compilation_result.steps}
    await apply_step_decorations(strategy, compiled_map, api)

    wdk_url = (
        site.strategy_url(wdk_strategy_id, compilation_result.root_step_id)
        if wdk_strategy_id is not None
        else None
    )

    # Mutate graph state with build results.
    graph.wdk_step_ids = dict(compiled_map)
    graph.wdk_strategy_id = wdk_strategy_id

    step_counts, root_count = await _fetch_step_counts(
        api, wdk_strategy_id, compiled_map
    )
    graph.step_counts = step_counts

    zeros = sorted(sid for sid, c in step_counts.items() if c == 0)

    return BuildResult(
        wdk_strategy_id=wdk_strategy_id,
        wdk_url=wdk_url,
        root_step_id=compilation_result.root_step_id,
        root_count=root_count,
        step_count=len(compilation_result.steps),
        counts=step_counts,
        zero_step_ids=zeros,
        compilation=compilation_result,
    )
# --------------------------------------------------------------------------- # Result count lookup # ---------------------------------------------------------------------------
async def get_result_count(
    api: StrategyBuildAPI,
    wdk_step_id: int,
    wdk_strategy_id: int | None = None,
) -> StepCountResult:
    """Return the result count for a built WDK step.

    When ``wdk_strategy_id`` is given, the strategy payload is fetched first.
    If that payload has an integer ``estimatedSize`` for the step, that value
    is used. Otherwise the count comes from ``api.get_step_count``.

    :raises TypeError: If strategy payload is malformed.
    :raises Exception: On WDK API errors (propagated to caller).
    """
    cached_size: object = None

    if wdk_strategy_id is not None:
        payload = await api.get_strategy(wdk_strategy_id)
        if not isinstance(payload, dict):
            raise TypeError("Expected dict from get_strategy")
        steps = payload.get("steps")
        step_entry = steps.get(str(wdk_step_id)) if isinstance(steps, dict) else None
        if isinstance(step_entry, dict):
            cached_size = step_entry.get("estimatedSize")

    if isinstance(cached_size, int):
        return StepCountResult(step_id=wdk_step_id, count=cached_size)

    live_count = await api.get_step_count(wdk_step_id)
    return StepCountResult(step_id=wdk_step_id, count=live_count)
# --------------------------------------------------------------------------- # Convenience entry points (resolve integrations internally) # ---------------------------------------------------------------------------
async def build_strategy_for_site(
    *,
    graph: StrategyGraph,
    site_id: str,
    root_step_id: str | None = None,
    strategy_name: str | None = None,
    description: str | None = None,
) -> BuildResult:
    """Run :func:`build_strategy` with the API and site info for ``site_id``.

    Entry point for the AI tool layer: it resolves the integration objects
    itself, so callers never import from the integrations package.
    """
    return await build_strategy(
        graph=graph,
        api=_get_build_api(site_id),
        site=_get_site_info(site_id),
        site_id=site_id,
        root_step_id=root_step_id,
        strategy_name=strategy_name,
        description=description,
    )
async def get_result_count_for_site(
    site_id: str,
    wdk_step_id: int,
    wdk_strategy_id: int | None = None,
) -> StepCountResult:
    """Run :func:`get_result_count` with the API for ``site_id``.

    Entry point for the AI tool layer.
    """
    return await get_result_count(
        _get_build_api(site_id),
        wdk_step_id,
        wdk_strategy_id,
    )