Source code for veupath_chatbot.domain.strategy.tree

"""Shared tree walkers for strategy trees.

Two families of trees appear throughout the codebase:

1. **Dict-based trees** -- raw WDK ``stepTree`` payloads or ``PlanStepNode.to_dict()``
   output.  Children live under ``"primaryInput"`` and ``"secondaryInput"`` keys.

2. **AST trees** -- :class:`PlanStepNode` objects with ``.primary_input`` /
   ``.secondary_input`` attributes.

This module provides generic, reusable walkers for both so that every call
site does not need to re-implement the recursive descent.
"""

from collections.abc import Callable

from veupath_chatbot.domain.strategy.ast import PlanStepNode
from veupath_chatbot.platform.types import JSONObject

# ── Dict-based tree walkers ───────────────────────────────────────────


[docs] def walk_dict_tree(root: object, visitor: Callable[[JSONObject], None]) -> None: """Pre-order walk of a dict-based step tree. Calls *visitor* on every node (the node dict itself), then recurses into ``primaryInput`` and ``secondaryInput`` when they are dicts. No-ops silently if *root* is not a dict. """ if not isinstance(root, dict): return visitor(root) pi = root.get("primaryInput") if isinstance(pi, dict): walk_dict_tree(pi, visitor) si = root.get("secondaryInput") if isinstance(si, dict): walk_dict_tree(si, visitor)
[docs] def collect_dict_nodes(root: object) -> list[JSONObject]: """Collect all nodes in a dict-based tree (pre-order).""" nodes: list[JSONObject] = [] walk_dict_tree(root, nodes.append) return nodes
[docs] def collect_dict_leaves(root: object) -> list[JSONObject]: """Collect leaf nodes (no primaryInput or secondaryInput) from a dict tree.""" leaves: list[JSONObject] = [] def _visit(node: JSONObject) -> None: pi = node.get("primaryInput") si = node.get("secondaryInput") if not isinstance(pi, dict) and not isinstance(si, dict): leaves.append(node) walk_dict_tree(root, _visit) return leaves
[docs] def collect_dict_combine_nodes(root: object) -> list[JSONObject]: """Collect combine (binary) nodes from a dict tree. A combine node has both ``primaryInput`` and ``secondaryInput`` as dicts. """ combines: list[JSONObject] = [] def _visit(node: JSONObject) -> None: pi = node.get("primaryInput") si = node.get("secondaryInput") if isinstance(pi, dict) and isinstance(si, dict): combines.append(node) walk_dict_tree(root, _visit) return combines
[docs] def count_dict_nodes(root: object) -> int: """Count all nodes in a dict-based tree.""" if not isinstance(root, dict): return 0 count = 1 pi = root.get("primaryInput") if isinstance(pi, dict): count += count_dict_nodes(pi) si = root.get("secondaryInput") if isinstance(si, dict): count += count_dict_nodes(si) return count
[docs] def map_dict_tree( root: JSONObject, transform: Callable[[JSONObject], JSONObject], ) -> JSONObject: """Bottom-up map: apply *transform* to every node, children first. Children are replaced with their transformed versions before the parent is passed to *transform*. The original tree is **not** mutated; each node dict is shallow-copied before transformation. """ import copy node = copy.copy(root) pi = node.get("primaryInput") if isinstance(pi, dict): node["primaryInput"] = map_dict_tree(pi, transform) si = node.get("secondaryInput") if isinstance(si, dict): node["secondaryInput"] = map_dict_tree(si, transform) return transform(node)
# ── PlanStepNode (AST) walkers ────────────────────────────────────────
[docs] def walk_plan_tree(root: PlanStepNode, visitor: Callable[[PlanStepNode], None]) -> None: """Pre-order walk of a PlanStepNode AST. Calls *visitor* on every node, then recurses into ``.primary_input`` and ``.secondary_input``. """ visitor(root) if root.primary_input is not None: walk_plan_tree(root.primary_input, visitor) if root.secondary_input is not None: walk_plan_tree(root.secondary_input, visitor)
[docs] def collect_plan_nodes(root: PlanStepNode) -> list[PlanStepNode]: """Collect all AST nodes (pre-order).""" nodes: list[PlanStepNode] = [] walk_plan_tree(root, nodes.append) return nodes
[docs] def collect_plan_leaves(root: PlanStepNode) -> list[PlanStepNode]: """Collect leaf AST nodes (no primary or secondary input).""" leaves: list[PlanStepNode] = [] def _visit(node: PlanStepNode) -> None: if node.primary_input is None and node.secondary_input is None: leaves.append(node) walk_plan_tree(root, _visit) return leaves