"""Step creation business logic.
Extracts the validation-heavy step creation workflow from the AI tool layer
so it can be tested and reused independently. I/O dependencies are injected
via callbacks or explicit parameters, with one exception: transform
validation obtains its WDK client directly through ``get_wdk_client``.
"""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Literal, cast
from veupath_chatbot.domain.parameters.specs import (
adapt_param_specs,
find_input_step_param,
unwrap_search_data,
)
from veupath_chatbot.domain.strategy.ast import PlanStepNode
from veupath_chatbot.domain.strategy.ops import ColocationParams, CombineOp, parse_op
from veupath_chatbot.domain.strategy.session import StrategyGraph
from veupath_chatbot.integrations.veupathdb.factory import get_wdk_client
from veupath_chatbot.platform.errors import ErrorCode, ValidationError
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import JSONObject, JSONValue
from veupath_chatbot.services.catalog.param_validation import validate_parameters
# Module-level logger named after this module.
logger = get_logger(__name__)
# Search name assigned by ``create_step`` to a binary (combine) step when the
# caller supplies no ``search_name``.
COMBINE_PLACEHOLDER_SEARCH_NAME = "__combine__"
# Callback type aliases for injected dependencies.
# Called in this module as (record_type, search_name, <bool>, <bool>) and awaited;
# a ``None`` result is treated as "search unknown / record type unresolved".
# NOTE(review): the meaning of the two bool flags is defined by the injected
# implementation and is not visible in this module — confirm against the caller.
ResolveRecordTypeFn = Callable[
    [str | None, str | None, bool, bool], Awaitable[str | None]
]
# Called as (search_name, record_type) and awaited; the result is attached to the
# "search not found" error payload as ``recordTypeHint``.
FindRecordTypeHintFn = Callable[[str, str | None], Awaitable[str | None]]
# Not called directly here; forwarded unchanged to ``validate_parameters``.
ExtractVocabOptionsFn = Callable[[JSONObject], list[str]]
[docs]
@dataclass
class StepCreationResult:
    """Result of step creation: either a successfully added step or an error payload.

    ``create_step`` returns this with ``step`` and ``step_id`` set and
    ``error=None`` on success, or with ``step=None``, ``step_id=None`` and
    ``error`` set on failure.
    """

    # The node that was added to the graph, or None when creation failed.
    step: PlanStepNode | None
    # The identifier returned by ``graph.add_step`` for the new node, or None on failure.
    step_id: str | None
    # Error payload describing why creation failed, or None on success.
    error: JSONObject | None
[docs]
def coerce_wdk_boolean_question_params(
    *,
    parameters: JSONObject,
) -> tuple[str | None, str | None, str | None]:
    """Translate WDK boolean-question parameters into a structural combine.

    WDK may describe a combine as a "boolean_question_*" search whose
    parameters carry ``bq_left_op*``, ``bq_right_op*`` and ``bq_operator``.
    The strategy graph represents combines structurally, so those keys are
    lifted out of the parameter dict here.

    *parameters* is modified in place: every ``bq_*`` key that gets consumed
    is removed, and this happens even when the final result is
    ``(None, None, None)`` because the triple turned out to be incomplete.

    :param parameters: Search parameters that may contain ``bq_left_op*``,
        ``bq_right_op*`` and ``bq_operator`` entries.
    :returns: ``(left_step_id, right_step_id, operator)`` as strings when all
        three are present and non-empty, otherwise ``(None, None, None)``.
    """
    nothing = (None, None, None)
    if not isinstance(parameters, dict) or not parameters:
        return nothing

    def _take_operand(prefix: str) -> str | None:
        """Pop keys starting with *prefix* until one holds a non-None value."""
        # Iterate over a snapshot of the keys because entries are removed
        # from the dict while scanning.
        for name in list(parameters):
            if not str(name).startswith(prefix):
                continue
            value = parameters.pop(name, None)
            if value is not None:
                return str(value)
        return None

    left_operand = _take_operand("bq_left_op")
    right_operand = _take_operand("bq_right_op")

    # The operator is always stripped from the parameters in favour of the
    # structural operator on the step.
    raw_operator = parameters.pop("bq_operator", None)
    operator_name = None if raw_operator is None else str(raw_operator)

    if not (left_operand and right_operand and operator_name):
        return nothing
    return left_operand, right_operand, operator_name
def _find_consumer(graph: StrategyGraph, step_id: str) -> str | None:
    """Return the id of the first step in *graph* that uses *step_id* as an input.

    Steps are scanned in ``graph.steps`` iteration order; for each one the
    primary input is checked before the secondary input. Steps lacking either
    attribute are tolerated.

    :param graph: Graph whose ``steps`` mapping is searched.
    :param step_id: Id of the step whose consumer is wanted.
    :returns: The consuming step's id, or None when no step consumes it.
    """
    for candidate in graph.steps.values():
        for slot in ("primary_input", "secondary_input"):
            upstream = getattr(candidate, slot, None)
            if getattr(upstream, "id", None) == step_id:
                return candidate.id
    return None
def _validate_inputs(
    graph: StrategyGraph,
    primary_input_step_id: str | None,
    secondary_input_step_id: str | None,
    operator: str | None,
) -> tuple[PlanStepNode | None, PlanStepNode | None, JSONObject | None]:
    """Look up the referenced input steps and check the combination is coherent.

    Checks, in order: the primary step exists (when an id is given); the
    secondary step exists (when an id is given); a secondary input is
    accompanied by a primary input; a secondary input is accompanied by an
    operator.

    :param graph: Graph the step ids are resolved against via ``get_step``.
    :param primary_input_step_id: Id of the primary input step, if any.
    :param secondary_input_step_id: Id of the secondary input step, if any.
    :param operator: Combine operator; only required when a secondary input is given.
    :returns: ``(primary_input, secondary_input, error)``. When ``error`` is
        not None both steps are None and the caller should return the error
        immediately.
    """
    primary = graph.get_step(primary_input_step_id) if primary_input_step_id else None
    if primary_input_step_id and not primary:
        return (
            None,
            None,
            tool_error(
                ErrorCode.STEP_NOT_FOUND,
                "Primary input step not found.",
                graphId=graph.id,
                stepId=primary_input_step_id,
            ),
        )

    # Without a secondary id there is nothing further to cross-check.
    if not secondary_input_step_id:
        return primary, None, None

    secondary = graph.get_step(secondary_input_step_id)
    if not secondary:
        problem = tool_error(
            ErrorCode.STEP_NOT_FOUND,
            "Secondary input step not found.",
            graphId=graph.id,
            stepId=secondary_input_step_id,
        )
    elif primary is None:
        problem = tool_error(
            ErrorCode.INVALID_STRATEGY,
            "secondary_input_step_id requires primary_input_step_id.",
            graphId=graph.id,
        )
    elif not operator:
        problem = tool_error(
            ErrorCode.INVALID_STRATEGY,
            "operator is required when secondary_input_step_id is provided.",
            graphId=graph.id,
        )
    else:
        return primary, secondary, None
    return None, None, problem
def _validate_root_status(
    graph: StrategyGraph,
    step_id: str,
    label: str,
) -> JSONObject | None:
    """Ensure *step_id* is currently a subtree root of *graph*.

    :param graph: Graph whose ``roots`` collection is consulted.
    :param step_id: Id of the step intended to be used as an input.
    :param label: Role of the input (callers pass "primary"/"secondary");
        accepted but not used in the body.
    :returns: None when the step is a root; otherwise an error payload naming
        the step that already consumes it and listing the available roots.
    """
    if step_id not in graph.roots:
        consuming_step = _find_consumer(graph, step_id)
        message = (
            f"Step '{step_id}' is not a subtree root — it is already consumed by step '{consuming_step}'. "
            "Only current subtree roots can be used as inputs."
        )
        return tool_error(
            ErrorCode.INVALID_STRATEGY,
            message,
            graphId=graph.id,
            stepId=step_id,
            consumedBy=consuming_step,
            availableRoots=cast(JSONValue, sorted(graph.roots)),
        )
    return None
async def _resolve_and_set_record_type(
    graph: StrategyGraph,
    record_type: str | None,
    search_name: str | None,
    resolve_record_type_for_search: ResolveRecordTypeFn,
) -> str:
    """Pick a best-effort record type, store it on *graph*, and return it.

    Preference order: the graph's existing ``record_type`` (when truthy), the
    explicit *record_type* argument, the type resolved from *search_name* via
    the injected resolver, and finally the literal ``"gene"``.

    :param graph: Graph whose ``record_type`` attribute is read and then overwritten.
    :param record_type: Caller-supplied record type, if any.
    :param search_name: Search used to look up a record type when none is known yet.
    :param resolve_record_type_for_search: Injected async resolver.
    :returns: The record type that was written to ``graph.record_type``.
    """
    candidate = graph.record_type or record_type
    needs_lookup = candidate is None and bool(search_name)
    if needs_lookup:
        candidate = await resolve_record_type_for_search(
            None, search_name, False, True
        )
    chosen = "gene" if candidate is None else candidate
    graph.record_type = chosen
    return chosen
async def _validate_leaf_step(
    *,
    graph: StrategyGraph,
    site_id: str,
    resolved_record_type: str,
    search_name: str,
    parameters: JSONObject,
    resolve_record_type_for_search: ResolveRecordTypeFn,
    find_record_type_hint: FindRecordTypeHintFn,
    extract_vocab_options: ExtractVocabOptionsFn,
    validation_error_payload: Callable[[ValidationError], JSONObject],
) -> JSONObject | None:
    """Validate a step that has no inputs.

    Three checks run in sequence: the search must resolve to a record type,
    its parameters must pass ``validate_parameters``, and the reference and
    comparison sample parameters must not be identical.

    Side effect: once the search resolves, ``graph.record_type`` is updated to
    the resolved record type, even if a later check fails.

    :param graph: Graph whose ``record_type`` is updated.
    :param site_id: Site identifier forwarded to ``validate_parameters``.
    :param resolved_record_type: Record type established so far.
    :param search_name: Search to validate.
    :param parameters: Parameters for the search.
    :param resolve_record_type_for_search: Injected async resolver.
    :param find_record_type_hint: Injected async hint lookup for error payloads.
    :param extract_vocab_options: Forwarded to ``validate_parameters``.
    :param validation_error_payload: Converts a ``ValidationError`` to a payload.
    :returns: An error payload, or None when the step is valid.
    """
    confirmed_type = await resolve_record_type_for_search(
        resolved_record_type, search_name, True, True
    )
    if confirmed_type is None:
        hint = await find_record_type_hint(search_name, resolved_record_type)
        return tool_error(
            ErrorCode.SEARCH_NOT_FOUND,
            f"Unknown or invalid search: {search_name}",
            recordType=resolved_record_type,
            recordTypeHint=hint,
        )
    graph.record_type = confirmed_type

    try:
        await validate_parameters(
            site_id=site_id,
            record_type=confirmed_type,
            search_name=search_name,
            parameters=parameters,
            resolve_record_type_for_search=resolve_record_type_for_search,
            find_record_type_hint=find_record_type_hint,
            extract_vocab_options=extract_vocab_options,
        )
    except ValidationError as failure:
        return validation_error_payload(failure)

    # Guard: a fold-change search whose reference and comparison samples are
    # the same yields meaningless results. Report it early so the caller can
    # correct the parameters.
    reference = parameters.get("samples_fc_ref_generic")
    if not reference:
        reference = parameters.get("samples_percentile_generic")
    comparison = parameters.get("samples_fc_comp_generic")
    if not (reference and comparison):
        return None
    if str(reference) != str(comparison):
        return None
    return tool_error(
        ErrorCode.VALIDATION_ERROR,
        "Reference and comparison samples are identical — this will produce "
        "meaningless fold-change results. Set different samples for reference "
        "vs comparison.",
        searchName=search_name,
        ref=reference,
        comp=comparison,
    )
async def _validate_transform_step(
    *,
    graph: StrategyGraph,
    site_id: str,
    resolved_record_type: str,
    search_name: str,
    parameters: JSONObject,
    resolve_record_type_for_search: ResolveRecordTypeFn,
    find_record_type_hint: FindRecordTypeHintFn,
    extract_vocab_options: ExtractVocabOptionsFn,
    validation_error_payload: Callable[[ValidationError], JSONObject],
) -> JSONObject | None:
    """Validate a step that has a primary input and no secondary input.

    The search must resolve to a record type, its parameters must pass
    ``validate_parameters``, and the search metadata fetched from the WDK
    client must expose an input-step parameter.

    Side effect: once the search resolves, ``graph.record_type`` is updated to
    the resolved record type, even if a later check fails.

    :param graph: Graph whose ``record_type`` is updated.
    :param site_id: Site identifier used for validation and the WDK client.
    :param resolved_record_type: Record type established so far.
    :param search_name: Search to validate as a transform.
    :param parameters: Parameters for the search.
    :param resolve_record_type_for_search: Injected async resolver.
    :param find_record_type_hint: Injected async hint lookup for error payloads.
    :param extract_vocab_options: Forwarded to ``validate_parameters``.
    :param validation_error_payload: Converts a ``ValidationError`` to a payload.
    :returns: An error payload, or None when the step is a valid transform.
    """
    confirmed_type = await resolve_record_type_for_search(
        resolved_record_type, search_name, True, True
    )
    if confirmed_type is None:
        hint = await find_record_type_hint(search_name, resolved_record_type)
        return tool_error(
            ErrorCode.SEARCH_NOT_FOUND,
            f"Unknown or invalid search: {search_name}",
            recordType=resolved_record_type,
            recordTypeHint=hint,
        )
    graph.record_type = confirmed_type

    # Check the supplied parameters against the WDK specs.
    try:
        await validate_parameters(
            site_id=site_id,
            record_type=confirmed_type,
            search_name=search_name,
            parameters=parameters,
            resolve_record_type_for_search=resolve_record_type_for_search,
            find_record_type_hint=find_record_type_hint,
            extract_vocab_options=extract_vocab_options,
        )
    except ValidationError as failure:
        return validation_error_payload(failure)

    # Fetch the search metadata; any failure while doing so is reported as a
    # validation error payload rather than raised.
    try:
        client = get_wdk_client(site_id)
        search_details = await client.get_search_details(
            confirmed_type, search_name, expand_params=True
        )
    except Exception as failure:
        return tool_error(
            ErrorCode.VALIDATION_ERROR,
            "Failed to load search metadata for transform validation.",
            recordType=confirmed_type,
            searchName=search_name,
            detail=str(failure),
        )

    # A transform is only usable if the search takes an input step.
    param_specs = adapt_param_specs(unwrap_search_data(search_details) or {})
    if find_input_step_param(param_specs):
        return None
    return tool_error(
        ErrorCode.INVALID_STRATEGY,
        f"Search '{search_name}' cannot be used as a transform: it does not accept an input step. "
        f"Call list_transforms(record_type='{confirmed_type}') to see available transforms (e.g. GenesByOrthologs for ortholog conversion).",
        recordType=confirmed_type,
        searchName=search_name,
        suggestedFix={
            "message": "Call list_transforms() to find the correct transform search, or create this as a leaf step and combine with INTERSECT/UNION/MINUS.",
            "asLeaf": {"searchName": search_name, "recordType": confirmed_type},
        },
    )
def _build_colocation_params(
    operator: CombineOp | None,
    upstream: int | None,
    downstream: int | None,
    strand: str | None,
) -> ColocationParams | None:
    """Create colocation settings for a COLOCATE combine.

    :param operator: Parsed combine operator; anything other than
        ``CombineOp.COLOCATE`` yields None.
    :param upstream: Upstream distance; a missing or falsy value becomes 0.
    :param downstream: Downstream distance; a missing or falsy value becomes 0.
    :param strand: One of "same", "opposite" or "both"; any other value
        (including None) falls back to "both".
    :returns: A ``ColocationParams`` instance, or None for other operators.
    """
    if operator != CombineOp.COLOCATE:
        return None

    allowed_strands = ("same", "opposite", "both")
    chosen_strand: Literal["same", "opposite", "both"] = (
        cast(Literal["same", "opposite", "both"], strand)
        if strand in allowed_strands
        else "both"
    )
    upstream_distance = upstream or 0
    downstream_distance = downstream or 0
    return ColocationParams(
        upstream=upstream_distance,
        downstream=downstream_distance,
        strand=chosen_strand,
    )
[docs]
async def create_step(
    *,
    graph: StrategyGraph,
    site_id: str,
    search_name: str | None = None,
    parameters: JSONObject | None = None,
    record_type: str | None = None,
    primary_input_step_id: str | None = None,
    secondary_input_step_id: str | None = None,
    operator: str | None = None,
    display_name: str | None = None,
    upstream: int | None = None,
    downstream: int | None = None,
    strand: str | None = None,
    resolve_record_type_for_search: ResolveRecordTypeFn,
    find_record_type_hint: FindRecordTypeHintFn,
    extract_vocab_options: ExtractVocabOptionsFn,
    validation_error_payload: Callable[[ValidationError], JSONObject],
) -> StepCreationResult:
    """Create a new strategy step with full validation.

    Step kind is inferred from structure:

    - leaf step: no inputs
    - unary/transform step: primary_input_step_id only
    - binary/combine step: primary_input_step_id + secondary_input_step_id (+ operator)

    The caller's *parameters* dict is never modified: validation and the
    WDK ``bq_*`` translation operate on a shallow copy, and that copy is what
    gets stored on the created step.

    :param graph: Strategy graph the step is validated against and added to.
        Its ``record_type`` may be updated as a side effect, including on
        some failure paths.
    :param site_id: Site identifier forwarded to parameter validation.
    :param search_name: Search for the step. Required for leaf and transform
        steps; combine steps default to ``COMBINE_PLACEHOLDER_SEARCH_NAME``.
    :param parameters: Search parameters. If no inputs/operator are given and
        these contain WDK ``bq_left_op*``/``bq_right_op*``/``bq_operator``
        entries, they are translated into structural inputs.
    :param record_type: Record type hint used when the graph has none yet.
    :param primary_input_step_id: Id of the primary input step, if any.
    :param secondary_input_step_id: Id of the secondary input step, if any.
        Must differ from ``primary_input_step_id``.
    :param operator: Combine operator; required with a secondary input.
    :param display_name: Display name for the step; defaults to the search name.
    :param upstream: Colocation upstream distance (COLOCATE only; defaults to 0).
    :param downstream: Colocation downstream distance (COLOCATE only; defaults to 0).
    :param strand: Colocation strand (COLOCATE only; defaults to "both").
    :param resolve_record_type_for_search: Injected async record-type resolver.
    :param find_record_type_hint: Injected async hint lookup for error payloads.
    :param extract_vocab_options: Forwarded to parameter validation.
    :param validation_error_payload: Converts a ``ValidationError`` to a payload.
    :returns: StepCreationResult with either step/step_id or error.
    """

    def _failure(payload: JSONObject) -> StepCreationResult:
        """Wrap an error payload in a failed result."""
        return StepCreationResult(step=None, step_id=None, error=payload)

    # Shallow-copy so that stripping bq_* keys below cannot silently alter the
    # dict object owned by the caller.
    parameters = dict(parameters) if parameters else {}

    # WDK compatibility: if the caller encoded a boolean combine as WDK
    # boolean-question parameters, translate it into structural inputs.
    if not primary_input_step_id and not secondary_input_step_id and not operator:
        left, right, bq_operator = coerce_wdk_boolean_question_params(
            parameters=parameters
        )
        if left and right and bq_operator:
            primary_input_step_id = left
            secondary_input_step_id = right
            operator = bq_operator

    # Validate and resolve input steps.
    primary_input, secondary_input, error = _validate_inputs(
        graph, primary_input_step_id, secondary_input_step_id, operator
    )
    if error is not None:
        return _failure(error)

    # A step may only be consumed once: after serving as the primary input it
    # is no longer a subtree root, so it cannot also be the secondary input.
    if (
        secondary_input is not None
        and primary_input_step_id == secondary_input_step_id
    ):
        return _failure(
            tool_error(
                ErrorCode.INVALID_STRATEGY,
                "primary_input_step_id and secondary_input_step_id must refer "
                "to different steps.",
                graphId=graph.id,
                stepId=primary_input_step_id,
            )
        )

    # Every input must currently be a subtree root (_validate_root_status
    # returns None for roots, so no separate pre-check is needed).
    for input_node, input_id, label in (
        (primary_input, primary_input_step_id, "primary"),
        (secondary_input, secondary_input_step_id, "secondary"),
    ):
        if input_node is None or input_id is None:
            continue
        root_error = _validate_root_status(graph, input_id, label)
        if root_error is not None:
            return _failure(root_error)

    # Resolve record type (best effort; also stored on the graph).
    resolved_record_type = await _resolve_and_set_record_type(
        graph, record_type, search_name, resolve_record_type_for_search
    )

    # Only combine steps may omit search_name.
    is_binary = primary_input is not None and secondary_input is not None
    if not search_name:
        if not is_binary:
            return _failure(
                tool_error(
                    ErrorCode.INVALID_STRATEGY,
                    "search_name is required for leaf and transform steps.",
                    graphId=graph.id,
                )
            )
        search_name = COMBINE_PLACEHOLDER_SEARCH_NAME

    # Leaf (no inputs) and transform (primary only) steps carry a real search
    # whose parameters must be validated; combine steps skip this.
    # _validate_inputs guarantees a secondary input never appears without a
    # primary one, so "no secondary" covers exactly these two kinds.
    if secondary_input is None:
        validate_search = (
            _validate_leaf_step if primary_input is None else _validate_transform_step
        )
        search_error = await validate_search(
            graph=graph,
            site_id=site_id,
            resolved_record_type=resolved_record_type,
            search_name=search_name,
            parameters=parameters,
            resolve_record_type_for_search=resolve_record_type_for_search,
            find_record_type_hint=find_record_type_hint,
            extract_vocab_options=extract_vocab_options,
            validation_error_payload=validation_error_payload,
        )
        if search_error is not None:
            return _failure(search_error)

    # Parse operator and build colocation params.
    # NOTE(review): how parse_op reacts to an unrecognised operator is not
    # visible here; if it raises, the exception propagates to the caller
    # instead of becoming an error payload — confirm.
    combine_op = (
        parse_op(operator) if secondary_input is not None and operator else None
    )
    colocation = _build_colocation_params(combine_op, upstream, downstream, strand)

    # Build and add the step.
    step = PlanStepNode(
        search_name=search_name,
        parameters=parameters,
        primary_input=primary_input,
        secondary_input=secondary_input,
        operator=combine_op,
        colocation_params=colocation,
        display_name=display_name or search_name,
    )
    step_id = graph.add_step(step)
    logger.info("Created step", step_id=step_id, search=search_name)
    return StepCreationResult(step=step, step_id=step_id, error=None)