"""Strategy build service: compile, create/update, and count extraction.
Encapsulates all business logic for building a strategy on WDK. The AI tool
layer delegates to this module and only handles argument parsing and response
formatting.
"""
from dataclasses import dataclass
from typing import Protocol
from veupath_chatbot.domain.strategy.ast import PlanStepNode, StepTreeNode, StrategyAST
from veupath_chatbot.domain.strategy.compile import (
CompilationResult,
StepDecoratorAPI,
StrategyCompilerAPI,
apply_step_decorations,
compile_strategy,
)
from veupath_chatbot.domain.strategy.session import StrategyGraph
from veupath_chatbot.domain.strategy.validate import validate_strategy
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.catalog.searches import (
make_record_type_resolver,
resolve_record_type_from_steps,
)
from veupath_chatbot.services.experiment.helpers import extract_wdk_id
# Module logger: records build progress and best-effort failures (the WDK
# update fallback and the step-count lookup), which are logged rather than raised.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Protocols — I/O boundaries the build service depends on
# ---------------------------------------------------------------------------
[docs]
class StrategyBuildAPI(StrategyCompilerAPI, StepDecoratorAPI, Protocol):
    """Structural interface for everything the build service needs from WDK.

    Extends the step-compiler and step-decorator protocols with strategy
    create / read / update calls and a per-step count query. The real
    ``StrategyAPI`` from the integrations layer satisfies it.
    """

    async def create_strategy(
        self, step_tree: StepTreeNode, name: str, description: str | None = None
    ) -> JSONObject:
        """Create a strategy from ``step_tree`` and return the JSON response.

        This module passes the response to ``extract_wdk_id`` to obtain the
        new strategy's ID.
        """

    async def update_strategy(
        self,
        strategy_id: int,
        step_tree: StepTreeNode | None = None,
        name: str | None = None,
    ) -> JSONObject:
        """Update strategy ``strategy_id`` with a new step tree and/or name."""

    async def get_strategy(self, strategy_id: int) -> JSONObject:
        """Fetch the JSON payload describing strategy ``strategy_id``.

        This module reads the payload's ``steps`` mapping (keyed by the
        stringified step ID) for each step's ``estimatedSize``.
        """

    async def get_step_count(self, step_id: int) -> int:
        """Return the result count for WDK step ``step_id``."""
[docs]
class SiteInfoLike(Protocol):
"""Protocol for site metadata needed by the build service."""
self, strategy_id: int, root_step_id: int | None = None
) -> str: ...
# ---------------------------------------------------------------------------
# Factory helpers — resolve integrations so tool layer doesn't import them
# ---------------------------------------------------------------------------
def _get_build_api(site_id: str) -> StrategyBuildAPI:
    """Return the ``StrategyBuildAPI`` implementation for ``site_id``.

    The integrations factory is imported inside the function rather than at
    module top level (presumably to keep that dependency or import cycle out
    of module import time).
    """
    import veupath_chatbot.integrations.veupathdb.factory as factory_mod

    return factory_mod.get_strategy_api(site_id)
def _get_site_info(site_id: str) -> SiteInfoLike:
    """Return the site-info object for ``site_id`` from the integrations factory.

    Like ``_get_build_api``, the factory module is imported lazily inside the
    function body.
    """
    import veupath_chatbot.integrations.veupathdb.factory as factory_mod

    return factory_mod.get_site(site_id)
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
[docs]
@dataclass
class BuildResult:
    """Outcome of a successful strategy build.

    Produced by ``build_strategy``; the field order is the constructor's
    positional order.
    """
    # WDK strategy ID from create/update; None if no ID could be extracted.
    wdk_strategy_id: int | None
    # URL from ``site.strategy_url``; None when there is no strategy ID.
    wdk_url: str | None
    # WDK step ID of the compiled root step.
    root_step_id: int
    # Result count of the root step; None when the count lookup yielded none.
    root_count: int | None
    # Number of compiled steps.
    step_count: int
    # Per-step result counts (None where unknown); presumably keyed by local
    # step ID, since they are derived from the local->WDK step map — confirm.
    counts: dict[str, int | None]
    # Sorted keys of ``counts`` whose count is exactly 0.
    zero_step_ids: list[str]
    # Full compilation result, kept for callers that need per-step detail.
    compilation: CompilationResult
[docs]
@dataclass
class StepCountResult:
    """Outcome of a step count lookup (see ``get_result_count``)."""
    # WDK step ID whose count was looked up.
    step_id: int
    # Result count: the strategy payload's ``estimatedSize`` when it is an int,
    # otherwise the value from a direct step-count query.
    count: int
# ---------------------------------------------------------------------------
# Root resolution
# ---------------------------------------------------------------------------
[docs]
class RootResolutionError(Exception):
    """Raised when a single root step cannot be resolved from the graph.

    :param message: Human-readable explanation (becomes ``str(err)``).
    :param root_count: Number of subtree roots involved; defaults to 0.
    """

    root_count: int

    def __init__(self, message: str, root_count: int = 0) -> None:
        self.root_count = root_count
        super().__init__(message)
[docs]
def resolve_root_step(
    graph: StrategyGraph,
    explicit_root_step_id: str | None,
) -> PlanStepNode:
    """Resolve the root step from the graph.

    A non-empty ``explicit_root_step_id`` takes precedence. Otherwise the
    graph must have exactly one subtree root.

    :param graph: Strategy graph.
    :param explicit_root_step_id: Optional explicit root step ID override.
    :returns: The resolved root PlanStepNode.
    :raises RootResolutionError: When the explicit ID is not in the graph,
        when the graph has no roots or more than one root (``root_count`` is
        set in the latter case), or when the single root ID has no step.
    """
    if explicit_root_step_id:
        step = graph.get_step(explicit_root_step_id)
        if step:
            return step
        raise RootResolutionError(
            f"Explicit root step '{explicit_root_step_id}' not found in graph."
        )

    root_ids = graph.roots
    if len(root_ids) > 1:
        raise RootResolutionError(
            f"Graph has {len(root_ids)} subtree roots -- expected exactly 1 to build. "
            "Combine them first, or specify root_step_id.",
            root_count=len(root_ids),
        )
    if len(root_ids) == 1:
        root_id = next(iter(root_ids))
        step = graph.get_step(root_id)
        if step:
            return step
        # The graph does have a root; it just points at no stored step.
        # Report that inconsistency instead of claiming the graph is empty.
        raise RootResolutionError(f"Root step '{root_id}' not found in graph.")
    raise RootResolutionError("No steps in graph. Create steps before building.")
# ---------------------------------------------------------------------------
# Strategy AST creation
# ---------------------------------------------------------------------------
[docs]
def create_strategy_ast(
    graph: StrategyGraph,
    root_step: object,
    strategy_name: str | None,
    description: str | None,
) -> StrategyAST:
    """Build a :class:`StrategyAST` rooted at ``root_step`` and validate it.

    The record type comes from ``graph.record_type``, which must already be
    set (from step creation or auto-resolution in :func:`build_strategy`).
    The strategy name falls back to ``graph.name`` when ``strategy_name`` is
    empty.

    :param graph: Strategy graph.
    :param root_step: Root node; must be a PlanStepNode.
    :param strategy_name: Optional strategy name.
    :param description: Optional description.
    :returns: Validated StrategyAST.
    :raises TypeError: When ``root_step`` is not a PlanStepNode.
    :raises ValueError: When the record type is unset or validation fails
        (the message lists each error's path and message).
    """
    if not isinstance(root_step, PlanStepNode):
        raise TypeError(f"Expected PlanStepNode, got {type(root_step).__name__}")

    record_type = graph.record_type
    if not record_type:
        raise ValueError("Record type could not be inferred for execution.")

    candidate = StrategyAST(
        record_type=record_type,
        root=root_step,
        name=strategy_name or graph.name,
        description=description,
    )
    report = validate_strategy(candidate)
    if report.valid:
        return candidate

    problems = [{"path": err.path, "message": err.message} for err in report.errors]
    raise ValueError(f"Strategy validation failed: {problems}")
# ---------------------------------------------------------------------------
# WDK create-or-update
# ---------------------------------------------------------------------------
[docs]
async def create_or_update_wdk_strategy(
    api: StrategyBuildAPI,
    compilation_result: CompilationResult,
    strategy: StrategyAST,
    existing_wdk_id: int | None,
) -> int | None:
    """Create a new WDK strategy or update an existing one.

    If ``existing_wdk_id`` is provided, an update is attempted first. Any
    exception from that update (e.g. a 404 from WDK) is logged, and a new
    strategy is created instead. Errors from the create call propagate.

    :param api: WDK strategy API.
    :param compilation_result: Compiled steps; its ``step_tree`` is sent to WDK.
    :param strategy: Strategy AST supplying the name and description.
    :param existing_wdk_id: ID of a previously built WDK strategy, if any.
    :returns: ``existing_wdk_id`` after a successful update, otherwise the ID
        that ``extract_wdk_id`` pulls from the create response (may be None).
    """
    name = strategy.name or "Untitled Strategy"

    if existing_wdk_id is not None:
        try:
            await api.update_strategy(
                strategy_id=existing_wdk_id,
                step_tree=compilation_result.step_tree,
                name=name,
            )
        except Exception as update_err:
            # Deliberately broad: any update failure falls back to creating a
            # fresh strategy rather than failing the whole build.
            logger.warning(
                "Failed to update WDK strategy, will create new",
                wdk_strategy_id=existing_wdk_id,
                error=str(update_err),
            )
        else:
            # Kept outside the try so a failure here can never trigger the
            # create fallback and duplicate an already-updated strategy.
            logger.info(
                "Updated existing WDK strategy",
                wdk_strategy_id=existing_wdk_id,
            )
            return existing_wdk_id

    wdk_result = await api.create_strategy(
        step_tree=compilation_result.step_tree,
        name=name,
        description=strategy.description,
    )
    return extract_wdk_id(wdk_result)
# ---------------------------------------------------------------------------
# Step counts extraction
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Full build orchestration
# ---------------------------------------------------------------------------
[docs]
def _ensure_strategy_ast(
    graph: StrategyGraph,
    existing: StrategyAST | None,
    root_step: PlanStepNode,
    strategy_name: str | None,
    description: str | None,
) -> StrategyAST:
    """Return an AST for ``root_step``, reusing ``existing`` when it still contains it.

    A fresh AST is created, stored on the graph and recorded in its history
    when there is no existing AST or the resolved root is not one of its
    steps. Explicit ``strategy_name`` / ``description`` overrides are then
    applied.
    """
    strategy = existing
    if not strategy or (
        root_step.id is not None and strategy.get_step_by_id(root_step.id) is None
    ):
        strategy = create_strategy_ast(graph, root_step, strategy_name, description)
        graph.current_strategy = strategy
        # Use the AST's actual name, which may have fallen back to graph.name.
        graph.save_history(f"Created strategy: {strategy.name or 'Untitled Strategy'}")
    if strategy_name:
        strategy.name = strategy_name
        graph.name = strategy_name
    if description:
        # A reused AST keeps the description from its earlier build; apply the
        # caller's new description the same way the name override is applied.
        strategy.description = description
    return strategy


async def _fetch_step_counts(
    api: StrategyBuildAPI,
    wdk_strategy_id: int | None,
    compiled_map: dict[str, int],
) -> tuple[dict[str, int | None], int | None]:
    """Best-effort lookup of per-step counts and the root count for a built strategy.

    Returns ``({}, None)`` when there is no WDK strategy ID, when the payload
    is not a dict, or when the lookup fails (the failure is logged, not raised).
    """
    if wdk_strategy_id is None:
        return {}, None
    try:
        strategy_info = await api.get_strategy(wdk_strategy_id)
        if isinstance(strategy_info, dict):
            # NOTE(review): confirm ``extract_step_counts`` is defined/imported in
            # this module; a NameError here would be swallowed by the broad
            # except below, and counts would silently come back empty.
            step_counts, root_count = extract_step_counts(strategy_info, compiled_map)
            return step_counts, root_count
    except Exception as e:
        logger.warning("Strategy count lookup failed", error=str(e))
    return {}, None


async def build_strategy(
    *,
    graph: StrategyGraph,
    api: StrategyBuildAPI,
    site: SiteInfoLike,
    site_id: str,
    root_step_id: str | None = None,
    strategy_name: str | None = None,
    description: str | None = None,
) -> BuildResult:
    """Build or update a strategy on WDK.

    Orchestrates: root resolution, AST creation, WDK compilation,
    create-or-update, step decorations, count extraction, and
    graph state mutation.

    Record type is auto-resolved from leaf searches via the pre-cached
    SearchCatalog — callers never need to supply it.

    On return, the graph's ``current_strategy``, ``wdk_step_ids``,
    ``wdk_strategy_id`` and ``step_counts`` reflect the build (and ``name``
    when ``strategy_name`` is given).

    :raises RootResolutionError: If root step cannot be determined.
    :raises ValueError: If validation or record type inference fails.
    :raises Exception: On WDK API failures during compilation, creation or
        decoration. Step-count lookup failures are logged and produce empty
        counts instead.
    """
    existing_strategy = graph.current_strategy
    root_step = resolve_root_step(graph, root_step_id)

    # One resolver serves both record-type auto-resolution and compilation.
    resolver = await make_record_type_resolver(site_id)
    if not graph.record_type:
        graph.record_type = await resolve_record_type_from_steps(root_step, resolver)

    strategy = _ensure_strategy_ast(
        graph, existing_strategy, root_step, strategy_name, description
    )

    logger.info("Building strategy", name=strategy.name)
    compilation_result = await compile_strategy(
        strategy, api, site_id=site_id, resolve_search_record_type=resolver
    )
    wdk_strategy_id = await create_or_update_wdk_strategy(
        api, compilation_result, strategy, graph.wdk_strategy_id
    )

    compiled_map = {s.local_id: s.wdk_step_id for s in compilation_result.steps}
    await apply_step_decorations(strategy, compiled_map, api)

    wdk_url = (
        site.strategy_url(wdk_strategy_id, compilation_result.root_step_id)
        if wdk_strategy_id is not None
        else None
    )

    # Mutate graph state with build results.
    graph.wdk_step_ids = dict(compiled_map)
    graph.wdk_strategy_id = wdk_strategy_id

    step_counts, root_count = await _fetch_step_counts(
        api, wdk_strategy_id, compiled_map
    )
    graph.step_counts = step_counts

    return BuildResult(
        wdk_strategy_id=wdk_strategy_id,
        wdk_url=wdk_url,
        root_step_id=compilation_result.root_step_id,
        root_count=root_count,
        step_count=len(compilation_result.steps),
        counts=step_counts,
        zero_step_ids=sorted(sid for sid, c in step_counts.items() if c == 0),
        compilation=compilation_result,
    )
# ---------------------------------------------------------------------------
# Result count lookup
# ---------------------------------------------------------------------------
[docs]
async def get_result_count(
    api: StrategyBuildAPI,
    wdk_step_id: int,
    wdk_strategy_id: int | None = None,
) -> StepCountResult:
    """Return the result count for a built WDK step.

    When ``wdk_strategy_id`` is given, the strategy payload is fetched first
    (the cheaper path), and the step's ``estimatedSize`` is used if it is an
    int. Otherwise, or when no strategy ID is given, a direct step-count query
    is made.

    :raises TypeError: If ``get_strategy`` returns something other than a dict.
    :raises Exception: On WDK API errors (propagated to caller).
    """
    if wdk_strategy_id is not None:
        payload = await api.get_strategy(wdk_strategy_id)
        if not isinstance(payload, dict):
            raise TypeError("Expected dict from get_strategy")
        steps_by_id = payload.get("steps")
        step_entry = (
            steps_by_id.get(str(wdk_step_id)) if isinstance(steps_by_id, dict) else None
        )
        size = step_entry.get("estimatedSize") if isinstance(step_entry, dict) else None
        if isinstance(size, int):
            return StepCountResult(step_id=wdk_step_id, count=size)

    direct_count = await api.get_step_count(wdk_step_id)
    return StepCountResult(step_id=wdk_step_id, count=direct_count)
# ---------------------------------------------------------------------------
# Convenience entry points (resolve integrations internally)
# ---------------------------------------------------------------------------
[docs]
async def build_strategy_for_site(
    *,
    graph: StrategyGraph,
    site_id: str,
    root_step_id: str | None = None,
    strategy_name: str | None = None,
    description: str | None = None,
) -> BuildResult:
    """Build a strategy, resolving the WDK API and site info from ``site_id``.

    This is the entry point for the AI tool layer, so callers never import
    from the integrations layer themselves.
    """
    # Keyword arguments are evaluated left to right: the API is resolved
    # before the site info, and both before the build starts.
    return await build_strategy(
        graph=graph,
        api=_get_build_api(site_id),
        site=_get_site_info(site_id),
        site_id=site_id,
        root_step_id=root_step_id,
        strategy_name=strategy_name,
        description=description,
    )
[docs]
async def get_result_count_for_site(
    site_id: str,
    wdk_step_id: int,
    wdk_strategy_id: int | None = None,
) -> StepCountResult:
    """Look up a step's result count using the API resolved from ``site_id``.

    This is the entry point for the AI tool layer.
    """
    return await get_result_count(_get_build_api(site_id), wdk_step_id, wdk_strategy_id)