"""AST node types for strategy representation (WDK-aligned, untyped tree)."""
import math
from dataclasses import dataclass, field
from typing import Literal, cast
from uuid import uuid4

from veupath_chatbot.domain.strategy.ops import ColocationParams, CombineOp
from veupath_chatbot.platform.types import JSONObject, JSONValue
[docs]
class StepTreeNode:
"""Node in a WDK step tree.
Represents a single step with optional primary (and for combines, secondary)
input references. Used to build the ``stepTree`` payload for WDK strategy
creation. Pure data structure with no I/O.
"""
[docs]
def __init__(
self,
step_id: int,
primary_input: StepTreeNode | None = None,
secondary_input: StepTreeNode | None = None,
) -> None:
self.step_id = step_id
self.primary_input = primary_input
self.secondary_input = secondary_input
[docs]
def to_dict(self) -> JSONObject:
"""Convert to WDK stepTree format."""
result: JSONObject = {"stepId": self.step_id}
if self.primary_input:
result["primaryInput"] = self.primary_input.to_dict()
if self.secondary_input:
result["secondaryInput"] = self.secondary_input.to_dict()
return result
[docs]
def generate_step_id() -> str:
    """Return a fresh random step identifier shaped like ``step_<8 hex chars>``."""
    suffix = uuid4().hex[:8]
    return "step_" + suffix
# ---------------------------------------------------------------------------
# Shared parsers for filters, analyses, reports, and colocation params.
# Used by both from_dict() and session.hydrate_graph_from_steps_data().
# ---------------------------------------------------------------------------
[docs]
def parse_filters(raw: JSONValue) -> list[StepFilter]:
    """Build :class:`StepFilter` objects from a raw JSON list.

    A non-list ``raw`` yields an empty list. Entries that are not objects,
    or whose ``name`` is missing, empty, or not a string, are skipped.

    :param raw: Raw JSON value, expected to be a list of filter objects.
    :returns: The parsed filters, in input order.
    """
    if not isinstance(raw, list):
        return []
    result: list[StepFilter] = []
    for entry in raw:
        if isinstance(entry, dict):
            filter_name = entry.get("name")
            if isinstance(filter_name, str) and filter_name:
                result.append(
                    StepFilter(
                        name=filter_name,
                        value=entry.get("value"),
                        disabled=bool(entry.get("disabled", False)),
                    )
                )
    return result
[docs]
def parse_analyses(raw: JSONValue) -> list[StepAnalysis]:
    """Build :class:`StepAnalysis` objects from a raw JSON list.

    A non-list ``raw`` yields an empty list. Entries that are not objects,
    or that lack a non-empty string analysis type, are skipped. A
    non-object ``parameters`` becomes ``{}`` and a non-string custom name
    becomes ``None``.

    :param raw: Raw JSON value, expected to be a list of analysis objects.
    :returns: The parsed analyses, in input order.
    """
    if not isinstance(raw, list):
        return []
    result: list[StepAnalysis] = []
    for entry in raw:
        if not isinstance(entry, dict):
            continue
        # Accept both camelCase and snake_case key spellings.
        kind = entry.get("analysisType") or entry.get("analysis_type")
        if not (isinstance(kind, str) and kind):
            continue
        maybe_params = entry.get("parameters")
        maybe_name = entry.get("customName") or entry.get("custom_name")
        result.append(
            StepAnalysis(
                analysis_type=kind,
                parameters=maybe_params if isinstance(maybe_params, dict) else {},
                custom_name=maybe_name if isinstance(maybe_name, str) else None,
            )
        )
    return result
[docs]
def parse_reports(raw: JSONValue) -> list[StepReport]:
    """Build :class:`StepReport` objects from a raw JSON list.

    A non-list ``raw`` yields an empty list and non-object entries are
    skipped. Every object entry produces a report. The name falls back to
    ``"standard"`` when missing or not a string, and a non-object
    ``config`` becomes ``{}``.

    :param raw: Raw JSON value, expected to be a list of report objects.
    :returns: The parsed reports, in input order.
    """
    if not isinstance(raw, list):
        return []
    result: list[StepReport] = []
    for entry in raw:
        if not isinstance(entry, dict):
            continue
        # Accept both camelCase and snake_case key spellings.
        name = entry.get("reportName") or entry.get("report_name")
        cfg = entry.get("config")
        result.append(
            StepReport(
                report_name=name if isinstance(name, str) else "standard",
                config=cfg if isinstance(cfg, dict) else {},
            )
        )
    return result
[docs]
def parse_colocation_params(raw: JSONValue) -> ColocationParams | None:
    """Parse colocation parameters from raw JSON data.

    This is a lenient parser: unusable values fall back to defaults instead
    of raising. Upstream/downstream default to ``0`` and strand defaults to
    ``"both"``.

    :param raw: Raw JSON value, expected to be an object with optional
        ``upstream``, ``downstream`` and ``strand`` keys.
    :returns: A :class:`ColocationParams`, or ``None`` when ``raw`` is not a
        JSON object.
    """
    if not isinstance(raw, dict):
        return None

    def as_int(value: JSONValue) -> int:
        # Numbers are truncated to int; anything else falls back to 0.
        # Non-finite floats (NaN/Infinity, which Python's ``json`` module
        # accepts by default) would make ``int()`` raise ValueError or
        # OverflowError, so they fall back to 0 as well.
        if isinstance(value, int):
            return int(value)
        if isinstance(value, float) and math.isfinite(value):
            return int(value)
        return 0

    upstream = as_int(raw.get("upstream", 0))
    downstream = as_int(raw.get("downstream", 0))
    strand_raw = raw.get("strand", "both")
    strand: Literal["same", "opposite", "both"]
    if isinstance(strand_raw, str) and strand_raw in ("same", "opposite", "both"):
        strand = cast(Literal["same", "opposite", "both"], strand_raw)
    else:
        strand = "both"
    return ColocationParams(upstream=upstream, downstream=downstream, strand=strand)
[docs]
@dataclass
class StepFilter:
"""Filter applied to a step's result."""
name: str
value: JSONValue
disabled: bool = False
[docs]
def to_dict(self) -> JSONObject:
return {
"name": self.name,
"value": self.value,
"disabled": self.disabled,
}
[docs]
@dataclass
class StepAnalysis:
"""Analysis configuration for a step."""
analysis_type: str
parameters: JSONObject = field(default_factory=dict)
custom_name: str | None = None
[docs]
def to_dict(self) -> JSONObject:
result: JSONObject = {
"analysisType": self.analysis_type,
"parameters": self.parameters,
}
if self.custom_name:
result["customName"] = self.custom_name
return result
[docs]
@dataclass
class StepReport:
"""Report request attached to a step."""
report_name: str = "standard"
config: JSONObject = field(default_factory=dict)
[docs]
def to_dict(self) -> JSONObject:
return {
"reportName": self.report_name,
"config": self.config,
}
[docs]
@dataclass
class PlanStepNode:
    """Recursive, untyped strategy node.

    The node carries no explicit kind; it is derived from which inputs are
    set (see :meth:`infer_kind`):

    - combine: both ``primary_input`` and ``secondary_input``
    - transform: ``primary_input`` only
    - search: no primary input
    """

    # Name of the search this step runs.
    search_name: str
    # Search parameters (serialized as ``parameters``; ``{}`` when falsy).
    parameters: JSONObject = field(default_factory=dict)
    # Upstream step(s) feeding this node.
    primary_input: PlanStepNode | None = None
    secondary_input: PlanStepNode | None = None
    # Combine operator; serialized via its ``.value``.
    operator: CombineOp | None = None
    # Colocation settings (upstream/downstream/strand).
    colocation_params: ColocationParams | None = None
    # Display name; serialization falls back to ``search_name``.
    display_name: str | None = None
    filters: list[StepFilter] = field(default_factory=list)
    analyses: list[StepAnalysis] = field(default_factory=list)
    reports: list[StepReport] = field(default_factory=list)
    # Optional weight, serialized as ``wdkWeight`` when set.
    wdk_weight: int | None = None
    # Step identifier; defaults to a fresh ``generate_step_id()`` value.
    id: str = field(default_factory=generate_step_id)

    def infer_kind(self) -> str:
        """Classify the node from which inputs are wired up.

        :returns: ``"combine"`` when both inputs are set, ``"transform"``
            when only the primary input is set, otherwise ``"search"``.
        """
        if self.primary_input is None:
            return "search"
        return "transform" if self.secondary_input is None else "combine"

    def to_dict(self) -> JSONObject:
        """Serialize this node and its inputs recursively to a JSON object.

        Optional keys are emitted only when the corresponding attribute is
        set (or, for the filter/analysis/report lists, non-empty).
        """
        out: JSONObject = {
            "id": self.id,
            "searchName": self.search_name,
            "displayName": self.display_name or self.search_name,
            "parameters": self.parameters or {},
        }
        primary, secondary = self.primary_input, self.secondary_input
        if primary is not None:
            out["primaryInput"] = primary.to_dict()
        if secondary is not None:
            out["secondaryInput"] = secondary.to_dict()
        if self.operator is not None:
            out["operator"] = self.operator.value
        coloc = self.colocation_params
        if coloc is not None:
            out["colocationParams"] = {
                "upstream": coloc.upstream,
                "downstream": coloc.downstream,
                "strand": coloc.strand,
            }
        collections = (
            ("filters", self.filters),
            ("analyses", self.analyses),
            ("reports", self.reports),
        )
        for key, items in collections:
            if items:
                out[key] = [item.to_dict() for item in items]
        if self.wdk_weight is not None:
            out["wdkWeight"] = self.wdk_weight
        return out
[docs]
@dataclass
class StrategyAST:
"""Complete strategy represented as an AST."""
record_type: str
root: PlanStepNode
name: str | None = None
description: str | None = None
metadata: JSONObject | None = None
[docs]
def to_dict(self) -> JSONObject:
"""Convert to dictionary representation."""
metadata: JSONObject = dict(self.metadata or {})
# Ensure name/description are always reflected in metadata.
if self.name is not None:
metadata["name"] = self.name
if self.description is not None:
metadata["description"] = self.description
return {
"recordType": self.record_type,
"root": self.root.to_dict(),
"metadata": metadata or None,
}
[docs]
def get_all_steps(self) -> list[PlanStepNode]:
"""Get all steps in the tree (depth-first)."""
steps: list[PlanStepNode] = []
def visit(node: PlanStepNode) -> None:
if node.primary_input is not None:
visit(node.primary_input)
if node.secondary_input is not None:
visit(node.secondary_input)
steps.append(node)
visit(self.root)
return steps
[docs]
def get_step_by_id(self, step_id: str) -> PlanStepNode | None:
"""Find a step by its ID.
:param step_id: Step identifier.
"""
for step in self.get_all_steps():
if step.id == step_id:
return step
return None
[docs]
def from_dict(data: JSONObject) -> StrategyAST:
    """Parse a strategy from its dictionary representation.

    Accepts the key shape produced by :meth:`StrategyAST.to_dict`. Unlike the
    lenient ``parse_*`` helpers, structural problems raise rather than being
    skipped; optional per-node fields (display name, id, weight, filters,
    analyses, reports) still fall back to defaults when unusable.

    :param data: Object with a string ``recordType``, an object ``root`` and
        an optional ``metadata`` object (its ``name``/``description`` are
        used when they are strings).
    :returns: The parsed :class:`StrategyAST`.
    :raises ValueError: If ``recordType`` or ``root`` is missing/invalid, or
        any node is malformed: missing ``searchName``, non-object
        ``parameters``, ``secondaryInput`` without ``primaryInput`` or
        ``operator``, or colocation params inconsistent with the operator.
    """

    def parse_node(node_data: JSONObject) -> PlanStepNode:
        # Inputs are parsed (recursively) before this node's own
        # operator/colocation checks, so errors deeper in the tree surface first.
        search_name = node_data.get("searchName")
        if not isinstance(search_name, str) or not search_name:
            raise ValueError("Missing searchName")
        params = node_data.get("parameters") or {}
        if not isinstance(params, dict):
            raise ValueError("parameters must be an object")
        primary_raw = node_data.get("primaryInput")
        secondary_raw = node_data.get("secondaryInput")
        primary = parse_node(primary_raw) if isinstance(primary_raw, dict) else None
        secondary = (
            parse_node(secondary_raw) if isinstance(secondary_raw, dict) else None
        )
        op_raw = node_data.get("operator")
        # NOTE(review): unknown operator strings are rejected by CombineOp
        # itself — presumably a ValueError if CombineOp is an Enum; confirm.
        operator = CombineOp(op_raw) if isinstance(op_raw, str) and op_raw else None
        # A non-object colocationParams parses to None (silently ignored).
        colocation = parse_colocation_params(node_data.get("colocationParams"))
        # basic structural constraints
        if secondary is not None and primary is None:
            raise ValueError("secondaryInput requires primaryInput")
        if secondary is not None and operator is None:
            raise ValueError("operator is required when secondaryInput is present")
        if operator == CombineOp.COLOCATE and colocation is None:
            raise ValueError("colocationParams is required when operator is COLOCATE")
        if operator != CombineOp.COLOCATE and colocation is not None:
            raise ValueError(
                "colocationParams is only allowed when operator is COLOCATE"
            )
        display_name_raw = node_data.get("displayName")
        display_name = display_name_raw if isinstance(display_name_raw, str) else None
        id_raw = node_data.get("id")
        step_id = id_raw if isinstance(id_raw, str) else generate_step_id()
        wdk_weight_raw = node_data.get("wdkWeight")
        # Non-finite floats (NaN/Infinity, accepted by Python's json module by
        # default) would make int() raise ValueError/OverflowError; treat them
        # like any other unusable weight and fall back to None.
        wdk_weight = (
            int(wdk_weight_raw)
            if (
                isinstance(wdk_weight_raw, int)
                or (isinstance(wdk_weight_raw, float) and math.isfinite(wdk_weight_raw))
            )
            else None
        )
        return PlanStepNode(
            search_name=search_name,
            parameters=params,
            primary_input=primary,
            secondary_input=secondary,
            operator=operator,
            colocation_params=colocation,
            display_name=display_name,
            filters=parse_filters(node_data.get("filters")),
            analyses=parse_analyses(node_data.get("analyses")),
            reports=parse_reports(node_data.get("reports")),
            wdk_weight=wdk_weight,
            id=step_id,
        )

    record_type_raw = data.get("recordType")
    if not isinstance(record_type_raw, str):
        raise ValueError("Missing or invalid recordType")
    root_raw = data.get("root")
    if not isinstance(root_raw, dict):
        raise ValueError("Missing or invalid root")
    metadata_raw = data.get("metadata", {})
    metadata_obj = metadata_raw if isinstance(metadata_raw, dict) else {}
    name_raw = metadata_obj.get("name")
    name = name_raw if isinstance(name_raw, str) else None
    description_raw = metadata_obj.get("description")
    description = description_raw if isinstance(description_raw, str) else None
    return StrategyAST(
        record_type=record_type_raw,
        root=parse_node(root_raw),
        name=name,
        description=description,
        metadata=metadata_obj if metadata_obj else None,
    )