"""Compile strategy AST to WDK API calls."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from veupath_chatbot.domain.parameters.normalize import ParameterNormalizer
from veupath_chatbot.domain.parameters.specs import (
adapt_param_specs,
find_input_step_param,
unwrap_search_data,
)
from veupath_chatbot.domain.strategy.ast import (
PlanStepNode,
StepTreeNode,
StrategyAST,
)
from veupath_chatbot.domain.strategy.ops import CombineOp, get_wdk_operator
from veupath_chatbot.platform.errors import InternalError, ValidationError
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject, JSONValue
# Async callback type: awaited with a search name, it yields the owning record
# type, or None when the search cannot be resolved.
# Mirrors WDK's WdkModel.getQuestionByName() — a global lookup across all
# record types. Callers inject an implementation backed by the pre-cached
# SearchCatalog so no HTTP calls are needed at compile time.
ResolveRecordType = Callable[[str], Awaitable[str | None]]
# ---------------------------------------------------------------------------
# Protocol: I/O boundary the compiler depends on
# ---------------------------------------------------------------------------
class _CompilerClient(Protocol):
    """Subset of the WDK HTTP client that the compiler depends on.

    Both methods are awaited by the compiler while coercing step parameters;
    their results are passed through ``unwrap_search_data`` and
    ``adapt_param_specs``.
    """

    async def get_search_details(
        self,
        record_type: str,
        search_name: str,
        expand_params: bool = False,
    ) -> JSONObject:
        """Load the details of ``search_name`` under ``record_type``."""
        ...

    async def get_search_details_with_params(
        self,
        record_type: str,
        search_name: str,
        context: JSONObject,
        expand_params: bool = False,
    ) -> JSONObject:
        """Load search details again, supplying ``context`` parameter values.

        The compiler uses this as a retry when validation against the plain
        details fails.
        """
        ...
@runtime_checkable
class StrategyCompilerAPI(Protocol):
    """I/O boundary for strategy compilation.

    The compiler calls these methods to create WDK steps and datasets.
    Any object that satisfies this protocol can be injected -- the real
    :class:`StrategyAPI` from the integrations layer is one such object.
    """

    @property
    def client(self) -> _CompilerClient:
        """Client the compiler uses to load search parameter metadata."""
        ...

    async def create_step(
        self,
        record_type: str,
        search_name: str,
        parameters: JSONObject,
        custom_name: str | None = None,
        wdk_weight: int | None = None,
    ) -> JSONObject:
        """Create a search step.

        The compiler reads the new step's numeric ``id`` from the result.
        """
        ...

    async def create_combined_step(
        self,
        primary_step_id: int,
        secondary_step_id: int,
        boolean_operator: str,
        record_type: str,
        custom_name: str | None = None,
        wdk_weight: int | None = None,
    ) -> JSONObject:
        """Create a step combining two existing steps with ``boolean_operator``.

        The compiler reads the new step's numeric ``id`` from the result.
        """
        ...

    async def create_transform_step(
        self,
        input_step_id: int,
        transform_name: str,
        parameters: JSONObject,
        record_type: str = "transcript",
        custom_name: str | None = None,
        wdk_weight: int | None = None,
    ) -> JSONObject:
        """Create a transform step that takes ``input_step_id`` as its input.

        The compiler reads the new step's numeric ``id`` from the result.
        """
        ...

    async def create_dataset(self, ids: list[str]) -> int:
        """Upload ``ids`` as a dataset and return the dataset's integer ID."""
        ...
@runtime_checkable
class StepDecoratorAPI(Protocol):
    """I/O boundary for post-compilation step decorations (filters, analyses, reports)."""

    async def set_step_filter(
        self, step_id: int, filter_name: str, value: JSONValue, disabled: bool = False
    ) -> JSONValue:
        """Set the filter ``filter_name`` to ``value`` on step ``step_id``."""
        ...

    async def run_step_analysis(
        self,
        step_id: int,
        analysis_type: str,
        parameters: JSONObject | None = None,
        custom_name: str | None = None,
    ) -> JSONObject:
        """Run an analysis of ``analysis_type`` against step ``step_id``."""
        ...

    async def run_step_report(
        self, step_id: int, report_name: str, config: JSONObject | None = None
    ) -> JSONValue:
        """Run the report ``report_name`` against step ``step_id``."""
        ...
# Module-level logger used by the compiler code in this file.
logger = get_logger(__name__)
@dataclass
class CompiledStep:
    """Record linking one strategy node to the WDK step created for it."""

    # ID of the node inside the strategy AST.
    local_id: str
    # Numeric ID of the step that WDK created for the node.
    wdk_step_id: int
    # Kind of step: the compiler writes "search", "combine" or "transform".
    step_type: str
    # Human-readable label for the step.
    display_name: str
@dataclass
class CompilationResult:
    """Result of compiling a strategy to WDK."""

    # One entry per compiled node of the strategy.
    steps: list[CompiledStep]
    # Tree of WDK step IDs produced by the compilation.
    step_tree: StepTreeNode
    # WDK step ID of the strategy's root node.
    root_step_id: int

    def to_dict(self) -> JSONObject:
        """Convert to a dictionary with camelCase keys.

        Returns:
            A mapping with ``steps`` (one dict per compiled step),
            ``stepTree`` (the result of ``step_tree.to_dict()``) and
            ``rootStepId``.
        """
        serialized_steps = [
            {
                "localId": compiled.local_id,
                "wdkStepId": compiled.wdk_step_id,
                "type": compiled.step_type,
                "displayName": compiled.display_name,
            }
            for compiled in self.steps
        ]
        return {
            "steps": serialized_steps,
            "stepTree": self.step_tree.to_dict(),
            "rootStepId": self.root_step_id,
        }
def _extract_wdk_step_id(result: JSONObject) -> int:
    """Extract and validate a numeric WDK step ID from an API response.

    Args:
        result: Response mapping; the ID is read from its ``"id"`` key.

    Returns:
        The step ID as an ``int``. Integral floats such as ``42.0`` are
        accepted and converted.

    Raises:
        ValueError: If ``"id"`` is missing, is not a number, is a boolean,
            or is a float without an exact integer value (fractional, NaN
            or infinite).
    """
    raw_id = result.get("id")
    # bool is a subclass of int, so it must be excluded explicitly; otherwise
    # ``True`` would be returned as step ID 1.
    is_number = isinstance(raw_id, (int, float)) and not isinstance(raw_id, bool)
    # float.is_integer() is False for fractional values, NaN and infinity,
    # which would otherwise be truncated or fail inside int().
    if not is_number or (isinstance(raw_id, float) and not raw_id.is_integer()):
        raise ValueError(f"Expected numeric step ID, got {raw_id}")
    return int(raw_id)
class StrategyCompiler:
    """Compiles strategy AST to WDK API calls.

    Walks the strategy tree bottom-up, creating one WDK step per node through
    the injected :class:`StrategyCompilerAPI`, and records every created step
    so the caller can map local node IDs to WDK step IDs.
    """

    def __init__(
        self,
        api: StrategyCompilerAPI,
        site_id: str | None = None,
        resolve_record_type: bool = True,
        resolve_search_record_type: ResolveRecordType | None = None,
    ) -> None:
        """Store the collaborators used during compilation.

        Args:
            api: Object used to create steps and datasets and to load search
                metadata.
            site_id: Stored on the instance; not read by the compiler itself.
            resolve_record_type: When true, ``compile`` tries to derive the
                strategy-level record type from its leaf searches.
            resolve_search_record_type: Optional async callback mapping a
                search name to its record type.
        """
        self.api = api
        # Maps local node ID -> compiled step; reset at the start of compile().
        self._compiled_steps: dict[str, CompiledStep] = {}
        self.site_id = site_id
        self.resolve_record_type = resolve_record_type
        self._resolve_search_rt = resolve_search_record_type

    async def compile(self, strategy: StrategyAST) -> CompilationResult:
        """Compile strategy to WDK steps and tree.

        This creates all steps via the WDK API and builds
        the step tree structure for creating the strategy.

        Note that ``strategy.record_type`` is overwritten in place when record
        type resolution yields a different value.

        Raises:
            InternalError: If no compiled step was recorded for the root node.
        """
        logger.info("Compiling strategy", record_type=strategy.record_type)
        self._compiled_steps = {}
        if self.resolve_record_type:
            resolved_record_type = await self._resolve_strategy_record_type(strategy)
            if resolved_record_type and resolved_record_type != strategy.record_type:
                strategy.record_type = resolved_record_type
        # Compile the tree recursively (creates steps bottom-up)
        step_tree = await self._compile_node(strategy.root, strategy.record_type)
        # Get root step ID
        root_step = self._compiled_steps.get(strategy.root.id)
        if not root_step:
            raise InternalError(
                title="Strategy compilation failed",
                detail="Failed to compile root step.",
            )
        return CompilationResult(
            steps=list(self._compiled_steps.values()),
            step_tree=step_tree,
            root_step_id=root_step.wdk_step_id,
        )

    async def _compile_node(self, node: PlanStepNode, record_type: str) -> StepTreeNode:
        """Compile a single node, returning its step tree.

        Dispatches on ``node.infer_kind()``.

        Raises:
            ValueError: If the inferred kind is not search, combine or transform.
        """
        kind = node.infer_kind()
        if kind == "search":
            return await self._compile_search(node, record_type)
        if kind == "combine":
            return await self._compile_combine(node, record_type)
        if kind == "transform":
            return await self._compile_transform(node, record_type)
        raise ValueError(f"Unknown node kind: {kind}")

    async def _resolve_strategy_record_type(self, strategy: StrategyAST) -> str | None:
        """Resolve the strategy-level record type from leaf searches.

        Uses the injected callback to look up each leaf search's record type,
        then returns the common record type if all agree.

        Returns:
            The shared record type, or ``None`` when there is no callback, no
            search step, or any search cannot be resolved.

        Raises:
            ValidationError: If the searches resolve to more than one record type.
        """
        if self._resolve_search_rt is None:
            return None
        search_names = [
            step.search_name
            for step in strategy.get_all_steps()
            if step.infer_kind() == "search"
        ]
        if not search_names:
            return None
        resolved_types: set[str] = set()
        for search_name in search_names:
            rt = await self._resolve_search_rt(search_name)
            if not rt:
                # One unresolved search makes the overall answer unknown.
                return None
            resolved_types.add(rt)
        # The set is non-empty here: the loop ran at least once and every
        # iteration either returned or added an entry.
        if len(resolved_types) > 1:
            raise ValidationError(
                title="Strategy mixes record types",
                detail="Searches in this strategy belong to multiple record types and cannot be combined.",
                errors=[
                    {"recordType": record_type}
                    for record_type in sorted(resolved_types)
                ],
            )
        return next(iter(resolved_types))

    async def _compile_search(
        self, step: PlanStepNode, record_type: str
    ) -> StepTreeNode:
        """Compile a search step.

        Resolves the search's own record type, coerces its parameters, creates
        the step and records it under ``step.id``.
        """
        logger.debug("Compiling search step", step_id=step.id, search=step.search_name)
        search_rt = await self._resolve_search_record_type(
            step.search_name, record_type
        )
        parameters = await self._coerce_parameters(
            search_rt, step.search_name, step.parameters
        )
        result = await self.api.create_step(
            record_type=search_rt,
            search_name=step.search_name,
            parameters=parameters,
            custom_name=step.display_name,
            wdk_weight=step.wdk_weight,
        )
        wdk_step_id = _extract_wdk_step_id(result)
        self._compiled_steps[step.id] = CompiledStep(
            local_id=step.id,
            wdk_step_id=wdk_step_id,
            step_type="search",
            display_name=step.display_name or step.search_name,
        )
        return StepTreeNode(step_id=wdk_step_id)

    async def _compile_combine(
        self, step: PlanStepNode, record_type: str
    ) -> StepTreeNode:
        """Compile a combine step.

        Both inputs are compiled first (primary, then secondary). A COLOCATE
        operator is delegated to ``_compile_colocation``; every other operator
        becomes a combined step.

        Raises:
            ValueError: If an input or the operator is missing.
        """
        if not step.primary_input or not step.secondary_input:
            raise ValueError("Combine step missing inputs")
        if step.operator is None:
            raise ValueError("Combine step missing operator")
        logger.debug("Compiling combine step", step_id=step.id, op=step.operator.value)
        left_tree = await self._compile_node(step.primary_input, record_type)
        right_tree = await self._compile_node(step.secondary_input, record_type)
        if step.operator == CombineOp.COLOCATE:
            result = await self._compile_colocation(
                step, left_tree.step_id, right_tree.step_id, record_type
            )
        else:
            wdk_op = get_wdk_operator(step.operator)
            result = await self.api.create_combined_step(
                primary_step_id=left_tree.step_id,
                secondary_step_id=right_tree.step_id,
                boolean_operator=wdk_op,
                record_type=record_type,
                custom_name=step.display_name,
                wdk_weight=step.wdk_weight,
            )
        wdk_step_id = _extract_wdk_step_id(result)
        self._compiled_steps[step.id] = CompiledStep(
            local_id=step.id,
            wdk_step_id=wdk_step_id,
            step_type="combine",
            display_name=step.display_name or f"{step.operator.value}",
        )
        return StepTreeNode(
            step_id=wdk_step_id, primary_input=left_tree, secondary_input=right_tree
        )

    async def _compile_colocation(
        self,
        step: PlanStepNode,
        left_step_id: int,
        right_step_id: int,
        record_type: str,
    ) -> JSONObject:
        """Compile a colocation operation using WDK's GenesBySpanLogic search.

        GenesBySpanLogic ("Genes by Relative Location") has two ``input-step``
        params (``span_a``, ``span_b``). WDK requires them to be ``""`` at
        step creation — the actual input wiring happens via the ``stepTree``
        when the strategy is created/updated (``primaryInput`` → ``span_a``,
        ``secondaryInput`` → ``span_b``).

        The ``span_sentence`` param is vestigial but **required** (WDK rejects
        empty); the frontend always sets it to ``"sentence"``.

        ``right_step_id`` is only logged here; it is not sent in the step
        creation call.
        """
        # Offsets default to 0 when the step carries no colocation params.
        colocation = step.colocation_params
        upstream = colocation.upstream if colocation else 0
        downstream = colocation.downstream if colocation else 0
        logger.debug(
            "Compiling colocation",
            left=left_step_id,
            right=right_step_id,
            upstream=upstream,
            downstream=downstream,
        )
        params: JSONObject = {
            "span_sentence": "sentence",
            "span_operation": "overlap",
            "span_strand": "Both strands",
            "span_output": "a",
            "region_a": "upstream",
            "region_b": "exact",
            "span_begin_a": "start",
            "span_begin_direction_a": "-",
            "span_begin_offset_a": str(upstream),
            "span_end_a": "start",
            "span_end_direction_a": "-",
            "span_end_offset_a": str(downstream),
            "span_begin_b": "start",
            "span_begin_direction_b": "-",
            "span_begin_offset_b": "0",
            "span_end_b": "stop",
            "span_end_direction_b": "-",
            "span_end_offset_b": "0",
        }
        return await self.api.create_transform_step(
            input_step_id=left_step_id,
            transform_name="GenesBySpanLogic",
            parameters=params,
            record_type=record_type,
            custom_name=step.display_name or "Genomic colocation",
            wdk_weight=step.wdk_weight,
        )

    async def _resolve_search_record_type(
        self, search_name: str, default_record_type: str
    ) -> str:
        """Resolve the record type for a search.

        Uses the injected ``resolve_search_record_type`` callback (backed by
        the pre-cached SearchCatalog) which mirrors WDK's global
        ``getQuestionByName()`` lookup. Falls back to the strategy-level
        default when no callback is available or it returns a falsy value.
        """
        if self._resolve_search_rt is not None:
            resolved = await self._resolve_search_rt(search_name)
            if resolved:
                return resolved
        return default_record_type

    async def _compile_transform(
        self, step: PlanStepNode, record_type: str
    ) -> StepTreeNode:
        """Compile a transform step.

        The input node is compiled first; the transform's own record type is
        then resolved and its parameters coerced before the step is created.

        Raises:
            ValueError: If the step has no primary input.
        """
        if not step.primary_input:
            raise ValueError("Transform step missing primaryInput")
        logger.debug(
            "Compiling transform step", step_id=step.id, transform=step.search_name
        )
        input_tree = await self._compile_node(step.primary_input, record_type)
        transform_rt = await self._resolve_search_record_type(
            step.search_name, record_type
        )
        parameters = await self._coerce_parameters(
            transform_rt, step.search_name, step.parameters
        )
        result = await self.api.create_transform_step(
            input_step_id=input_tree.step_id,
            transform_name=step.search_name,
            parameters=parameters,
            record_type=transform_rt,
            custom_name=step.display_name,
            wdk_weight=step.wdk_weight,
        )
        wdk_step_id = _extract_wdk_step_id(result)
        self._compiled_steps[step.id] = CompiledStep(
            local_id=step.id,
            wdk_step_id=wdk_step_id,
            step_type="transform",
            display_name=step.display_name or step.search_name,
        )
        return StepTreeNode(step_id=wdk_step_id, primary_input=input_tree)

    @staticmethod
    def _is_dataset_id(value: str) -> bool:
        """Return True when ``value`` parses as an integer (an existing dataset ID)."""
        try:
            int(value)
        except (ValueError, TypeError):
            return False
        return True

    @staticmethod
    def _split_record_ids(value: str) -> list[str]:
        """Split ``value`` on commas and newlines, dropping blank tokens."""
        return [
            token
            for piece in value.replace("\n", ",").split(",")
            if (token := piece.strip())
        ]

    async def _coerce_parameters(
        self, record_type: str, search_name: str, parameters: JSONObject
    ) -> JSONObject:
        """Normalize step parameters against the search's parameter specs.

        Loads the search metadata, normalizes ``parameters`` against it,
        blanks the input-step parameter (if any) and uploads raw record IDs
        given for ``input-dataset`` parameters as datasets.

        Returns:
            The normalized parameter mapping to send to WDK.

        Raises:
            ValidationError: If the metadata cannot be loaded, or if the
                parameters fail validation and the context-aware retry does
                not fix it.
        """
        try:
            details = await self.api.client.get_search_details(
                record_type, search_name, expand_params=True
            )
        except Exception as exc:
            raise ValidationError(
                title="Failed to load search metadata",
                detail=f"Unable to load parameter metadata for '{search_name}' ({record_type}).",
                errors=[{"searchName": search_name, "recordType": record_type}],
            ) from exc
        details = unwrap_search_data(details) or {}
        specs = adapt_param_specs(details)
        normalizer = ParameterNormalizer(specs)
        try:
            normalized = normalizer.normalize(parameters or {})
        except ValidationError as validation_exc:
            # WDK question param metadata can be context-dependent. When validation fails, retry
            # with contextParamValues so dependent vocabularies/constraints can refresh.
            #
            # Some WDK deployments/questions error on this POST (500 Internal Error). If that
            # happens, keep the original specs and re-raise the validation error.
            try:
                details = await self.api.client.get_search_details_with_params(
                    record_type,
                    search_name,
                    context=parameters or {},
                    expand_params=True,
                )
            except Exception as exc:
                raise validation_exc from exc
            details = unwrap_search_data(details) or {}
            specs = adapt_param_specs(details)
            normalizer = ParameterNormalizer(specs)
            normalized = normalizer.normalize(parameters or {})
        input_step_param = find_input_step_param(specs)
        if input_step_param:
            normalized[input_step_param] = ""
        # Auto-upload datasets for input-dataset params.
        # WDK DatasetParam fields expect an integer dataset ID. If the LLM provided
        # raw record IDs (e.g. gene locus tags) instead, upload them as a transient
        # dataset and swap in the integer ID. This is infrastructure plumbing — the
        # LLM explicitly chose the IDs; we just translate the transport format.
        for param_name, spec in specs.items():
            if spec.param_type != "input-dataset":
                continue
            raw_value = normalized.get(param_name)
            if raw_value is None:
                continue
            str_value = str(raw_value).strip()
            # Already a valid dataset ID (integer)? Then there is nothing to do.
            if self._is_dataset_id(str_value):
                continue
            # Looks like raw IDs — split on commas/newlines and upload.
            ids = self._split_record_ids(str_value)
            if not ids:
                continue
            dataset_id = await self.api.create_dataset(ids)
            normalized[param_name] = str(dataset_id)
            logger.info(
                "Uploaded dataset for input-dataset param",
                param=param_name,
                id_count=len(ids),
                dataset_id=dataset_id,
            )
        return normalized
async def apply_step_decorations(
    strategy: StrategyAST,
    compiled_map: dict[str, int],
    api: StepDecoratorAPI,
) -> None:
    """Apply the declared decorations of every compiled step.

    Iterates over ``strategy.get_all_steps()``. Steps whose ID has no entry
    in ``compiled_map`` are skipped; for the others the step's filters, then
    its analyses, then its reports are applied one call at a time through
    ``api``.

    Args:
        strategy: Strategy whose steps carry the decorations.
        compiled_map: Maps local step IDs to WDK step IDs.
        api: Object that performs the decoration calls.
    """
    for node in strategy.get_all_steps():
        if (target_id := compiled_map.get(node.id)) is None:
            # This node was never compiled, so there is nothing to decorate.
            continue

        for flt in node.filters:
            await api.set_step_filter(
                step_id=target_id,
                filter_name=flt.name,
                value=flt.value,
                disabled=flt.disabled,
            )

        for requested in node.analyses:
            await api.run_step_analysis(
                step_id=target_id,
                analysis_type=requested.analysis_type,
                parameters=requested.parameters,
                custom_name=requested.custom_name,
            )

        for requested_report in node.reports:
            await api.run_step_report(
                step_id=target_id,
                report_name=requested_report.report_name,
                config=requested_report.config,
            )
async def compile_strategy(
    strategy: StrategyAST,
    api: StrategyCompilerAPI,
    site_id: str | None = None,
    resolve_record_type: bool = True,
    resolve_search_record_type: ResolveRecordType | None = None,
) -> CompilationResult:
    """Compile a strategy AST to WDK.

    Convenience wrapper: builds a one-off :class:`StrategyCompiler` from the
    given arguments and returns the result of its ``compile`` call.
    """
    return await StrategyCompiler(
        api,
        site_id=site_id,
        resolve_record_type=resolve_record_type,
        resolve_search_record_type=resolve_search_record_type,
    ).compile(strategy)