Source code for veupath_chatbot.ai.orchestration.delegation

"""Normalize and validate delegation inputs (nested plan structure).

This is AI-orchestration logic: it validates a model-produced *nested* plan
into a strict, executable shape.
"""

from collections.abc import Callable
from dataclasses import dataclass

from veupath_chatbot.domain.strategy.ops import parse_op
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import (
    JSONArray,
    JSONObject,
    JSONValue,
    as_json_object,
)


[docs] @dataclass(frozen=True) class DelegationPlan: goal: str tasks: JSONArray combines: JSONArray nodes_by_id: dict[str, JSONObject] dependents: dict[str, list[str]]
def _op_value(value: JSONValue) -> str | None: if value is None: return None try: return parse_op(str(value).strip()).value except Exception: return None
[docs] def build_delegation_plan( *, goal: str, plan: JSONObject | None, ) -> DelegationPlan | JSONObject: def plan_error(message: str, detail: str, **extra: JSONValue) -> JSONObject: return tool_error( "DELEGATION_PLAN_INVALID", message, goal=goal, detail=detail, **extra, ) if not isinstance(plan, dict): return plan_error( "plan is required when delegating.", "Provide a nested plan object as 'plan'.", ) # Compile nested plan -> DAG nodes. node_counter = 0 tasks: JSONArray = [] combines: JSONArray = [] # Structural dedupe: canonical node signature -> generated id seen_signatures: dict[str, str] = {} def new_id() -> str: nonlocal node_counter node_counter += 1 return f"node_{node_counter}" def _canon(value: JSONValue) -> JSONValue: """Best-effort canonicalization for hashing. :param value: Value to process. """ if isinstance(value, dict): return { str(k): _canon(v) for k, v in sorted(value.items(), key=lambda kv: str(kv[0])) } if isinstance(value, list): return [_canon(v) for v in value] if isinstance(value, str): return value.strip() return value def _get_field( node: JSONObject, *keys: str, default: JSONValue = None ) -> JSONValue: """Get field from node with multiple possible keys. :param node: Node dict to read from. :param keys: Possible keys to try. :param default: Default value if no key found. :returns: Value at first matching key, or default. """ for key in keys: if key in node: return node[key] return default def _compile_dependencies( *nodes: JSONValue, ) -> tuple[list[str], JSONObject | None]: """Compile child nodes and return their IDs, or an error. :param nodes: Child nodes to compile. :returns: Tuple of dependency IDs and optional error. """ dep_ids: list[str] = [] for node in nodes: if node is None: continue child_id = compile_node(node) if isinstance(child_id, dict): return [], child_id dep_ids.append(child_id) return dep_ids, None def _get_or_create_node_id( signature_obj: JSONObject, target_list: JSONArray, node_data: JSONObject, task_formatter: Callable[[str], str] | None = None, ) -> str: """Get existing node ID from signature or create new one. :param signature_obj: Object to use for signature matching. :param target_list: List to append new nodes to. :param node_data: Node data dict (will be modified with id). :param task_formatter: Optional function to format task field using node_id. :returns: Node ID (existing or newly created). """ signature = str(_canon(signature_obj)) existing = seen_signatures.get(signature) if existing: return existing node_id = new_id() seen_signatures[signature] = node_id node_data["id"] = node_id if task_formatter is not None: node_data["task"] = task_formatter(node_id) target_list.append(node_data) return node_id def compile_node(node: JSONValue) -> str | JSONObject: if not isinstance(node, dict): return plan_error( "Invalid plan node.", "Each node must be an object.", ) node_type = str(_get_field(node, "type", "kind") or "").strip().lower() # Be forgiving: infer node type when omitted but structure is unambiguous. if not node_type: if (_get_field(node, "operator", "op") is not None) and ( _get_field(node, "left") is not None or _get_field(node, "right") is not None or _get_field(node, "inputs") is not None ): node_type = "combine" elif _get_field(node, "task", "text") is not None: node_type = "task" if not node_type and node.get("id") is not None: # Model attempted an id-only reference. # Since we ignore ids, this is invalid. return plan_error( "Invalid plan node.", ( "Do not use id-only references. " "Provide a full node object with 'type'." ), ) if node_type in ("combine", "op", "operator"): op_raw = _get_field(node, "operator", "op") operator = _op_value(op_raw) if not operator: return plan_error( "Invalid combine operator.", "Combine node requires a valid operator.", nodeId=node.get("id"), operator=op_raw, ) inputs_raw = _get_field(node, "inputs") left = _get_field(node, "left") right = _get_field(node, "right") if inputs_raw is not None: if not isinstance(inputs_raw, list) or len(inputs_raw) != 2: return plan_error( "Invalid combine inputs.", "Combine node inputs must be a list of exactly 2 child nodes.", nodeId=node.get("id"), ) left_node, right_node = inputs_raw[0], inputs_raw[1] else: left_node, right_node = left, right if left_node is None or right_node is None: return plan_error( "Invalid combine inputs.", "Combine node requires left and right child nodes.", nodeId=node.get("id"), ) dep_ids, error = _compile_dependencies(left_node, right_node) if error is not None: return error left_id, right_id = dep_ids[0], dep_ids[1] display_name = _get_field(node, "display_name", "displayName") instructions = _get_field(node, "instructions") combine_depends_json: JSONArray = [left_id, right_id] combine_node_data: JSONObject = { "kind": "combine", "operator": operator, "inputs": [left_id, right_id], "depends_on": combine_depends_json, "display_name": display_name, "instructions": instructions, "task": display_name or "", # Will be formatted with node_id if needed } combine_signature_obj: JSONObject = { "kind": "combine", "operator": operator, "inputs": [left_id, right_id], "display_name": display_name, "instructions": instructions, } def task_formatter(nid: str) -> str: display_str = str(display_name) if display_name is not None else "" return display_str or f"Combine {nid} ({operator})" return _get_or_create_node_id( combine_signature_obj, combines, combine_node_data, task_formatter ) if node_type in ("task", "step", "subtask"): task_text = str(_get_field(node, "task", "text") or "").strip() if not task_text: return plan_error( "Invalid task node.", "Task node requires a non-empty 'task' string.", nodeId=node.get("id"), ) instructions = str(_get_field(node, "instructions") or "").strip() # Optional per-task context that will be passed to the sub-kani as # additional structured context (e.g. organism, recordType, dataset ids, # constraints). Allow common aliases since models vary: # context / parameters / params. context = _get_field(node, "context", "parameters", "params") if context is not None and not isinstance( context, (dict, list, str, int, float, bool) ): return plan_error( "Invalid task context.", ( "Task node 'context' must be a JSON-serializable " "object/array/string/primitive." ), nodeId=node.get("id"), contextType=type(context).__name__, ) input_node = _get_field(node, "input") dep_ids, error = ( _compile_dependencies(input_node) if input_node is not None else ([], None) ) if error is not None: return error # Convert list[str] to JSONArray (list[JSONValue]) for type compatibility from typing import cast task_depends_json: JSONArray = cast(JSONArray, dep_ids) task_node_data: JSONObject = { "kind": "task", "task": task_text, "instructions": instructions, "context": context, "depends_on": task_depends_json, } task_signature_obj: JSONObject = { "kind": "task", "task": task_text, "instructions": instructions, "context": context, "depends_on": task_depends_json, } return _get_or_create_node_id(task_signature_obj, tasks, task_node_data) return plan_error( "Invalid node type.", "Node 'type' must be either 'task' or 'combine'.", nodeId=node.get("id"), nodeType=node_type, ) root_id = compile_node(plan) if isinstance(root_id, dict): return root_id # Build nodes_by_id with explicit kinds; validate DAG. def _build_nodes_by_id(node_list: JSONArray) -> dict[str, JSONObject]: """Extract nodes from list into nodes_by_id dict. :param node_list: List of nodes. """ result: dict[str, JSONObject] = {} for node in node_list: if isinstance(node, dict): node_obj = as_json_object(node) node_id = node_obj.get("id") if isinstance(node_id, str): result[node_id] = node_obj return result nodes_by_id: dict[str, JSONObject] = {} nodes_by_id.update(_build_nodes_by_id(tasks)) nodes_by_id.update(_build_nodes_by_id(combines)) all_ids = set(nodes_by_id.keys()) if root_id not in all_ids: return plan_error( "Invalid root node.", "Root id missing after compilation.", rootId=root_id, ) indegree: dict[str, int] = dict.fromkeys(all_ids, 0) dependents: dict[str, list[str]] = {node_id: [] for node_id in all_ids} for node_id, node in nodes_by_id.items(): depends_on = node.get("depends_on") if isinstance(depends_on, list): for dep in depends_on: if isinstance(dep, str) and dep in all_ids: indegree[node_id] += 1 dependents[dep].append(node_id) queue = [node_id for node_id, count in indegree.items() if count == 0] processed = 0 pending = dict(indegree) while queue: current = queue.pop() processed += 1 for child in dependents.get(current, []): pending[child] -= 1 if pending[child] == 0: queue.append(child) if processed != len(all_ids): return plan_error( "Dependency cycle detected.", "Cycle detected in delegation graph (tasks/combines). Replan and retry.", ) return DelegationPlan( goal=goal, tasks=tasks, combines=combines, nodes_by_id=nodes_by_id, dependents=dependents, )