Source code for veupath_chatbot.domain.strategy.session

"""Stateful strategy session types (in-memory).

These types model the *working* state while a user (or an AI agent) is building a
VEuPathDB strategy during a chat session.
"""

from uuid import uuid4

from veupath_chatbot.domain.strategy.ast import (
    PlanStepNode,
    StrategyAST,
    from_dict,
    parse_analyses,
    parse_colocation_params,
    parse_filters,
    parse_reports,
)
from veupath_chatbot.domain.strategy.ops import CombineOp
from veupath_chatbot.platform.types import (
    JSONArray,
    JSONObject,
    as_json_object,
)

# Short alias for a single plan step node; used in the StrategyGraph
# annotations below (steps dict, add_step / get_step signatures).
Step = PlanStepNode


class StrategyGraph:
    """In-memory working state for one strategy graph.

    Tracks the step nodes, which of them are currently subtree roots, a
    history of serialized strategy snapshots used by ``undo``, and the results
    of the most recent WDK build (WDK step IDs, per-step counts, strategy ID).
    """

    def __init__(self, graph_id: str, name: str, site_id: str) -> None:
        self.id = graph_id
        self.name = name
        self.site_id = site_id
        # Best-effort record type for the graph (e.g. "gene"); filled in when
        # the first step is created or when a WDK strategy is imported.
        self.record_type: str | None = None
        self.current_strategy: StrategyAST | None = None
        # Step nodes keyed by their local step ID.
        self.steps: dict[str, Step] = {}
        # IDs of steps that no other step consumes as an input. ``add_step``
        # maintains this incrementally; a complete strategy has exactly one.
        self.roots: set[str] = set()
        # Snapshots appended by ``save_history`` and consumed by ``undo``.
        self.history: list[JSONObject] = []
        self.last_step_id: str | None = None
        # WDK build results, populated by build_strategy / compile:
        # local step ID -> WDK step ID ...
        self.wdk_step_ids: dict[str, int] = {}
        # ... local step ID -> estimatedSize ...
        self.step_counts: dict[str, int | None] = {}
        # ... and the ID of the strategy created on WDK.
        self.wdk_strategy_id: int | None = None

    def invalidate_build(self) -> None:
        """Forget all WDK build results so stale counts are never displayed.

        Use this after any edit that changes what a step means (parameters,
        search_name, operator, deletion). The WDK step ID map and the count
        map are emptied in place, and both ``wdk_strategy_id`` and
        ``current_strategy`` are reset to ``None``; the next
        ``build_strategy`` call repopulates them.
        """
        for build_cache in (self.step_counts, self.wdk_step_ids):
            build_cache.clear()
        self.current_strategy = None
        self.wdk_strategy_id = None

    def add_step(self, step: Step) -> str:
        """Insert ``step`` into the graph and update the root set.

        ``step`` is registered as a root; whichever of its
        ``primary_input`` / ``secondary_input`` are set stop being roots,
        since they now sit inside ``step``'s subtree. ``step`` also becomes
        ``last_step_id``.

        :param step: Step to add.
        :returns: The added step's ID.
        """
        new_id = step.id
        self.steps[new_id] = step
        self.roots.add(new_id)
        for consumed in (step.primary_input, step.secondary_input):
            if consumed:
                # discard() is a no-op when the input was not a root.
                self.roots.discard(consumed.id)
        self.last_step_id = new_id
        return new_id

    def get_step(self, step_id: str) -> Step | None:
        """Look up a step by its ID.

        :param step_id: Step ID.
        :returns: The matching step, or ``None`` when it is absent.
        """
        try:
            return self.steps[step_id]
        except KeyError:
            return None

    def recompute_roots(self) -> None:
        """Rebuild ``roots`` from scratch using the current ``steps``.

        A step counts as a root unless some other step names it as its
        ``primary_input`` or ``secondary_input`` (only non-empty string IDs
        are considered). Intended for bulk edits such as deletion or
        hydration, where incremental tracking is impractical.
        """
        consumed_ids: set[str] = set()
        for node in self.steps.values():
            for slot in ("primary_input", "secondary_input"):
                ref = getattr(getattr(node, slot, None), "id", None)
                if isinstance(ref, str) and ref:
                    consumed_ids.add(ref)
        self.roots = set(self.steps) - consumed_ids

    def save_history(self, description: str) -> None:
        """Append a snapshot of ``current_strategy`` to ``history``.

        Nothing is recorded when there is no current strategy.

        :param description: Human-readable label for the snapshot.
        """
        strategy = self.current_strategy
        if not strategy:
            return
        self.history.append(
            {"description": description, "strategy": strategy.to_dict()}
        )

    def undo(self) -> bool:
        """Roll back to the snapshot before the latest one.

        Drops the newest history entry and, when the now-latest entry holds a
        serialized strategy dict, rebuilds ``current_strategy`` from it along
        with ``steps``, ``roots`` and ``last_step_id`` so graph inspection
        stays consistent.

        :returns: ``False`` when fewer than two snapshots exist (nothing
            changes); ``True`` otherwise.
        """
        if len(self.history) <= 1:
            return False
        del self.history[-1]
        snapshot = self.history[-1].get("strategy")
        if not isinstance(snapshot, dict):
            return True
        restored = from_dict(as_json_object(snapshot))
        self.current_strategy = restored
        self.steps = {node.id: node for node in restored.get_all_steps()}
        self.recompute_roots()
        self.last_step_id = restored.root.id
        return True
class StrategySession:
    """Per-chat session holding at most one active strategy graph."""

    def __init__(self, site_id: str) -> None:
        self.site_id = site_id
        self.graph: StrategyGraph | None = None

    def add_graph(self, graph: StrategyGraph) -> None:
        """Install ``graph`` as the session's graph.

        The call is ignored when a different graph (by ID) is already
        active; a graph with the same ID replaces the current one.

        :param graph: Strategy graph to register.
        """
        active = self.graph
        if not active or active.id == graph.id:
            self.graph = graph

    def create_graph(self, name: str, graph_id: str | None = None) -> StrategyGraph:
        """Return the session graph, creating an empty one if none exists.

        When a graph is already active it is returned as-is (``graph_id`` is
        ignored), except that a non-empty ``name`` renames it.

        :param name: Graph name.
        :param graph_id: ID for a newly created graph; a random UUID4 string
            is used when omitted.
        :returns: The active graph.
        """
        active = self.graph
        if active:
            if name and active.name != name:
                active.name = name
            return active
        fresh = StrategyGraph(graph_id or str(uuid4()), name, self.site_id)
        self.graph = fresh
        return fresh

    def get_graph(self, graph_id: str | None) -> StrategyGraph | None:
        """Fetch the active graph if it matches ``graph_id``.

        :param graph_id: Graph ID to match, or ``None`` to accept whichever
            graph is active.
        :returns: The active graph on a match, otherwise ``None``.
        """
        active = self.graph
        if active and (graph_id is None or graph_id == active.id):
            return active
        return None
def _persisted_step_id(step: JSONObject) -> str | None:
    """Return the step's ``id`` as a non-empty string, or ``None`` if unusable."""
    raw = step.get("id")
    if raw is None:
        return None
    step_id = str(raw)
    return step_id or None


def _node_from_step_payload(step: JSONObject, step_id: str) -> PlanStepNode:
    """Build an (unwired) ``PlanStepNode`` from one persisted step dict."""
    kind = str(step.get("kind") or "").strip().lower()
    search_name = step.get("searchName")
    if not isinstance(search_name, str) or not search_name:
        search_name = "__combine__" if kind == "combine" else "__unknown__"

    parameters_raw = step.get("parameters")
    # Shallow copy so later edits to the node's parameters do not mutate the
    # caller's persisted payload.
    parameters: JSONObject = (
        dict(parameters_raw) if isinstance(parameters_raw, dict) else {}
    )

    display_name = step.get("displayName")
    if not isinstance(display_name, str) or not display_name.strip():
        display_name = search_name

    node = PlanStepNode(
        search_name=search_name,
        parameters=parameters,
        display_name=display_name,
        id=step_id,
    )
    node.filters = parse_filters(step.get("filters"))
    node.analyses = parse_analyses(step.get("analyses"))
    node.reports = parse_reports(step.get("reports"))

    op_raw = step.get("operator")
    if isinstance(op_raw, str) and op_raw:
        try:
            node.operator = CombineOp(op_raw)
        except (ValueError, TypeError):
            # Unknown operator string: keep the node, but without an operator.
            node.operator = None

    node.colocation_params = parse_colocation_params(step.get("colocationParams"))
    return node


def _resolve_input(
    nodes: dict[str, PlanStepNode], ref: object, owner: PlanStepNode
) -> PlanStepNode | None:
    """Return the node referenced by ``ref`` (an input step ID), if usable."""
    if ref is None:
        return None
    target = nodes.get(str(ref))
    # A step listed as its own input would make the step tree cyclic, which
    # breaks any recursive walk over it; ignore such references.
    if target is owner:
        return None
    return target


def _wire_step_inputs(
    step_dicts: list[JSONObject], nodes: dict[str, PlanStepNode]
) -> None:
    """Link each hydrated node to its persisted primary/secondary inputs."""
    for step in step_dicts:
        step_id = _persisted_step_id(step)
        if step_id is None:
            continue
        node = nodes.get(step_id)
        if node is None:
            continue
        primary = _resolve_input(nodes, step.get("primaryInputStepId"), node)
        if primary is not None:
            node.primary_input = primary
        secondary = _resolve_input(nodes, step.get("secondaryInputStepId"), node)
        if secondary is not None:
            node.secondary_input = secondary


def hydrate_graph_from_steps_data(
    graph: StrategyGraph,
    steps_data: JSONArray | object,
    *,
    root_step_id: str | None = None,
    record_type: str | None = None,
) -> None:
    """Hydrate an in-memory graph from persisted flat steps.

    This is used when we have a persisted ``steps`` list (and maybe
    ``root_step_id``) but no canonical ``plan`` to parse into an AST. It
    enables tools like ``list_current_steps`` to reflect existing UI-visible
    nodes.

    Accepts arbitrary input; non-list values (and non-dict list items, and
    items without a usable ``id``) are silently ignored. Hydrated nodes are
    merged into ``graph.steps`` without replacing steps already present.

    Input references are resolved by ID among the hydrated steps; a step that
    lists itself as an input is left unwired for that slot. Persisted
    ``wdkStepId`` / ``resultCount`` values are restored only when they are
    real integers (JSON booleans are rejected).

    :param graph: Strategy graph to hydrate.
    :param steps_data: Flat steps list from persistence (or any value).
    :param root_step_id: Root step ID (default: None).
    :param record_type: Record type (default: None).
    """
    # Type check first: arbitrary objects may not support plain truthiness.
    if not isinstance(steps_data, list) or not steps_data:
        return
    step_dicts: list[JSONObject] = [s for s in steps_data if isinstance(s, dict)]

    # First pass: build one node per usable step dict (later duplicates win).
    nodes: dict[str, PlanStepNode] = {}
    for step in step_dicts:
        step_id = _persisted_step_id(step)
        if step_id is not None:
            nodes[step_id] = _node_from_step_payload(step, step_id)

    # Second pass: connect inputs.
    _wire_step_inputs(step_dicts, nodes)

    # Attach hydrated nodes to the graph (don't blow away any already-loaded
    # plan steps).
    if not graph.steps:
        graph.steps = nodes
    else:
        for sid, node in nodes.items():
            graph.steps.setdefault(sid, node)

    # Best-effort record type context.
    if record_type and not graph.record_type:
        graph.record_type = record_type
    if not graph.record_type:
        for step in step_dicts:
            if step.get("recordType"):
                graph.record_type = str(step.get("recordType"))
                break

    # Restore WDK build state from persisted per-step fields. bool is a
    # subclass of int, so exclude it explicitly.
    for step in step_dicts:
        sid = _persisted_step_id(step)
        if sid is None:
            continue
        wdk_step_id = step.get("wdkStepId")
        if isinstance(wdk_step_id, int) and not isinstance(wdk_step_id, bool):
            graph.wdk_step_ids[sid] = wdk_step_id
        result_count = step.get("resultCount")
        if isinstance(result_count, int) and not isinstance(result_count, bool):
            graph.step_counts[sid] = result_count

    # Recompute the subtree-root set from the hydrated step graph.
    graph.recompute_roots()

    # Best-effort last-step pointer (used for plan emission when roots is
    # ambiguous).
    if root_step_id and str(root_step_id) in graph.steps:
        graph.last_step_id = str(root_step_id)
    elif len(graph.roots) == 1:
        graph.last_step_id = next(iter(graph.roots))
    elif not graph.last_step_id and graph.roots:
        # Several roots: pick the latest one in step insertion order rather
        # than relying on set iteration order, which varies with hash seeding.
        graph.last_step_id = [sid for sid in graph.steps if sid in graph.roots][-1]