Source code for veupath_chatbot.services.experiment.materialization

"""WDK strategy materialization for experiments.

Creates, persists, and cleans up WDK strategies from experiment configs,
including step tree materialization for multi-step and import modes.
"""

from veupath_chatbot.domain.strategy.ast import StepTreeNode
from veupath_chatbot.domain.strategy.ops import DEFAULT_COMBINE_OPERATOR
from veupath_chatbot.integrations.veupathdb.factory import get_strategy_api
from veupath_chatbot.integrations.veupathdb.strategy_api import StrategyAPI
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject, as_json_object
from veupath_chatbot.services.experiment.helpers import coerce_step_id, extract_wdk_id
from veupath_chatbot.services.experiment.types import (
    Experiment,
    ExperimentConfig,
)

logger = get_logger(__name__)


def _text_or(value: object, fallback: str) -> str:
    """Return ``str(value)``, or *fallback* when *value* is ``None``.

    Plan nodes come from JSON, where an optional field may be present with a
    ``null`` value.  ``str(None)`` would yield the literal text ``"None"``,
    so ``None`` is treated the same as a missing key.
    """
    return fallback if value is None else str(value)


def _build_colocation_params(coloc_raw: object) -> JSONObject:
    """Build the parameter dict for a ``GenesBySpanLogic`` colocation step.

    :param coloc_raw: Value of the node's ``colocationParams`` entry.  When it
        is a dict, its ``upstream`` / ``downstream`` values are used as the
        ``span_begin_offset_a`` / ``span_end_offset_a`` parameters; any other
        value (or a missing / ``None`` offset) falls back to ``"0"``.
    :returns: Parameter dict with every value rendered as a string.
    """
    upstream = "0"
    downstream = "0"
    if isinstance(coloc_raw, dict):
        upstream = _text_or(coloc_raw.get("upstream"), "0")
        downstream = _text_or(coloc_raw.get("downstream"), "0")
    return {
        "span_sentence": "sentence",
        "span_operation": "overlap",
        "span_strand": "Both strands",
        "span_output": "a",
        "region_a": "upstream",
        "region_b": "exact",
        "span_begin_a": "start",
        "span_begin_direction_a": "-",
        "span_begin_offset_a": upstream,
        "span_end_a": "start",
        "span_end_direction_a": "-",
        "span_end_offset_a": downstream,
        "span_begin_b": "start",
        "span_begin_direction_b": "-",
        "span_begin_offset_b": "0",
        "span_end_b": "stop",
        "span_end_direction_b": "-",
        "span_end_offset_b": "0",
    }


async def _materialize_step_tree(
    api: StrategyAPI,
    node: JSONObject,
    record_type: str,
    *,
    site_id: str = "",
) -> StepTreeNode:
    """Recursively create WDK steps from a ``PlanStepNode`` dict.

    Walks the tree bottom-up: the ``primaryInput`` subtree is materialised
    first, then the ``secondaryInput`` subtree, and finally the step for
    *node* itself, which references the already-created input steps.

    The kind of step created depends on which inputs are present:

    * both inputs: a combined (boolean) step, or a ``GenesBySpanLogic``
      transform when the operator is ``"COLOCATE"``;
    * primary input only: a transform step named by ``searchName``;
    * otherwise: a plain search step.

    ``searchName``, ``displayName`` and ``operator`` values of ``None`` are
    treated as absent, so they fall back to their defaults instead of being
    rendered as the literal string ``"None"``.

    :param api: Strategy API instance.
    :param node: ``PlanStepNode``-shaped dict.
    :param record_type: WDK record type for all steps.
    :param site_id: VEuPathDB site identifier.  Within this function it is
        only forwarded to the recursive calls.
    :returns: :class:`StepTreeNode` ready for strategy creation.
    """
    primary_node = node.get("primaryInput")
    secondary_node = node.get("secondaryInput")

    primary_tree: StepTreeNode | None = None
    secondary_tree: StepTreeNode | None = None

    # Inputs must exist in WDK before the step that consumes them.
    if isinstance(primary_node, dict):
        primary_tree = await _materialize_step_tree(
            api, primary_node, record_type, site_id=site_id
        )
    if isinstance(secondary_node, dict):
        secondary_tree = await _materialize_step_tree(
            api, secondary_node, record_type, site_id=site_id
        )

    search_name = _text_or(node.get("searchName"), "")
    raw_params = node.get("parameters")
    parameters: JSONObject = raw_params if isinstance(raw_params, dict) else {}
    # The display name defaults to the search name when missing or null.
    display_name = _text_or(node.get("displayName"), search_name)

    if primary_tree is not None and secondary_tree is not None:
        operator = _text_or(
            node.get("operator"), str(DEFAULT_COMBINE_OPERATOR.value)
        )
        if operator == "COLOCATE":
            # Colocation uses GenesBySpanLogic — two input-step params
            # (span_a, span_b) wired via stepTree at strategy creation.
            step = await api.create_transform_step(
                input_step_id=primary_tree.step_id,
                transform_name="GenesBySpanLogic",
                parameters=_build_colocation_params(node.get("colocationParams")),
                record_type=record_type,
                custom_name=display_name,
            )
        else:
            step = await api.create_combined_step(
                primary_step_id=primary_tree.step_id,
                secondary_step_id=secondary_tree.step_id,
                boolean_operator=operator,
                record_type=record_type,
                custom_name=display_name,
            )
        return StepTreeNode(
            coerce_step_id(step),
            primary_input=primary_tree,
            secondary_input=secondary_tree,
        )

    if primary_tree is not None:
        step = await api.create_transform_step(
            input_step_id=primary_tree.step_id,
            transform_name=search_name,
            parameters=parameters,
            record_type=record_type,
            custom_name=display_name,
        )
        return StepTreeNode(coerce_step_id(step), primary_input=primary_tree)

    # NOTE(review): a node that has only ``secondaryInput`` also lands here;
    # its already-materialised secondary subtree is not attached to the
    # returned node.  Confirm whether such nodes can occur upstream.
    step = await api.create_step(
        record_type=record_type,
        search_name=search_name,
        parameters=parameters,
        custom_name=display_name,
    )
    return StepTreeNode(coerce_step_id(step))


async def _persist_experiment_strategy(
    config: ExperimentConfig,
    experiment_id: str,
    *,
    override_tree: JSONObject | None = None,
) -> JSONObject:
    """Build and save the WDK strategy that backs an experiment's results.

    The experiment mode (``config.mode``, defaulting to ``"single"``) decides
    how the root step tree is obtained:

    * **import** with a ``source_strategy_id`` and no *override_tree*:
      delegated entirely to :func:`_persist_import_strategy`.
    * **multi-step** / **import** with a dict step tree available:
      the tree is materialised recursively into WDK steps.
    * anything else: a single search step is created from the config's
      ``search_name`` and ``parameters``.

    :param config: Experiment configuration.
    :param experiment_id: Unique experiment identifier; used in the
        strategy name (``exp:<experiment_id>``).
    :param override_tree: When truthy, this tree is materialised instead of
        the config's ``step_tree`` (used after tree optimisation).
    :returns: Dict with ``strategy_id`` and ``step_id``.
    :raises ValueError: If no strategy ID can be extracted from the
        strategy-creation response.
    """
    api = get_strategy_api(config.site_id)
    mode = config.mode or "single"

    # Plain import: copy the source strategy's steps rather than rebuilding.
    if mode == "import" and config.source_strategy_id and override_tree is None:
        return await _persist_import_strategy(api, config, experiment_id)

    # A truthy override wins; otherwise use the configured tree if it is a dict.
    if override_tree:
        tree_spec = override_tree
    elif isinstance(config.step_tree, dict):
        tree_spec = config.step_tree
    else:
        tree_spec = None

    if mode not in ("multi-step", "import") or not isinstance(tree_spec, dict):
        # No usable tree for this mode: fall back to one search step.
        single_step = await api.create_step(
            record_type=config.record_type,
            search_name=config.search_name,
            parameters=config.parameters or {},
            custom_name=f"Experiment: {config.name}",
        )
        root = StepTreeNode(coerce_step_id(single_step))
    else:
        root = await _materialize_step_tree(
            api, tree_spec, config.record_type, site_id=config.site_id
        )

    strategy_payload = await api.create_strategy(
        step_tree=root,
        name=f"exp:{experiment_id}",
        description=f"Persisted strategy for experiment {config.name}",
        is_internal=True,
    )
    strategy_id = extract_wdk_id(strategy_payload)
    if strategy_id is None:
        raise ValueError("Failed to create WDK strategy for experiment")

    root_step_id = root.step_id
    logger.info(
        "Persisted WDK strategy for experiment",
        experiment_id=experiment_id,
        strategy_id=strategy_id,
        step_id=root_step_id,
    )
    return {"strategy_id": strategy_id, "step_id": root_step_id}


async def _persist_import_strategy(
    api: StrategyAPI,
    config: ExperimentConfig,
    experiment_id: str,
) -> JSONObject:
    """Create an experiment strategy from a copy of an existing WDK strategy.

    Posts to the WDK ``duplicated-step-tree`` endpoint of the source
    strategy, converts the returned ``stepTree`` into :class:`StepTreeNode`
    objects, and creates a new internal strategy named
    ``exp:<experiment_id>`` from it.

    :param api: Strategy API instance.
    :param config: Experiment configuration (must have ``source_strategy_id``).
    :param experiment_id: Unique experiment identifier.
    :returns: Dict with ``strategy_id`` and ``step_id``.
    :raises ValueError: If ``source_strategy_id`` is missing or not an
        integer, the duplication response lacks ``stepTree``, a ``stepId``
        has an unsupported type, or no strategy ID can be extracted from the
        strategy-creation response.
    """
    if not config.source_strategy_id:
        raise ValueError("source_strategy_id is required for import mode")
    source_id = int(config.source_strategy_id)

    # WDK POST .../duplicated-step-tree returns {"stepTree": {...}}
    endpoint = f"/users/{api.user_id}/strategies/{source_id}/duplicated-step-tree"
    response = await api.client.post(endpoint)
    has_tree = isinstance(response, dict) and "stepTree" in response
    if not has_tree:
        raise ValueError(f"Failed to duplicate step tree from strategy {source_id}")

    duplicated = as_json_object(response["stepTree"])

    def _wrap(subtree: JSONObject) -> StepTreeNode:
        """Convert one duplicated-tree dict (and its inputs) to a node."""
        step_ref = subtree["stepId"]
        if not isinstance(step_ref, (int, str, float)):
            raise ValueError(f"Invalid stepId type: {type(step_ref)}")
        # Integers are used as-is; strings and floats are coerced via int().
        step_id = step_ref if isinstance(step_ref, int) else int(step_ref)

        left = subtree.get("primaryInput")
        right = subtree.get("secondaryInput")
        left_node = _wrap(left) if isinstance(left, dict) else None
        right_node = _wrap(right) if isinstance(right, dict) else None
        return StepTreeNode(
            step_id, primary_input=left_node, secondary_input=right_node
        )

    # The duplicated tree already carries step IDs, so it only needs
    # wrapping — no new steps are created here.
    root = _wrap(duplicated)

    strategy_payload = await api.create_strategy(
        step_tree=root,
        name=f"exp:{experiment_id}",
        description=f"Imported strategy for experiment {config.name}",
        is_internal=True,
    )
    strategy_id = extract_wdk_id(strategy_payload)
    if strategy_id is None:
        raise ValueError("Failed to create WDK strategy from imported tree")

    root_step_id = root.step_id
    logger.info(
        "Persisted imported WDK strategy for experiment",
        experiment_id=experiment_id,
        strategy_id=strategy_id,
        step_id=root_step_id,
    )
    return {"strategy_id": strategy_id, "step_id": root_step_id}


async def cleanup_experiment_strategy(experiment: Experiment) -> None:
    """Delete the persisted WDK strategy when an experiment is deleted.

    Does nothing when the experiment has no ``wdk_strategy_id``.  Deletion
    is best-effort: any exception raised while obtaining the API or deleting
    the strategy is logged as a warning and not re-raised, so a failed
    cleanup does not propagate to the caller.

    :param experiment: Experiment whose WDK strategy should be cleaned up.
    """
    if experiment.wdk_strategy_id is None:
        return
    try:
        api = get_strategy_api(experiment.config.site_id)
        await api.delete_strategy(experiment.wdk_strategy_id)
        logger.info(
            "Deleted WDK strategy for experiment",
            experiment_id=experiment.id,
            strategy_id=experiment.wdk_strategy_id,
        )
    except Exception as exc:
        # Deliberately broad: cleanup failures are reported, never raised.
        logger.warning(
            "Failed to delete WDK strategy during experiment cleanup",
            experiment_id=experiment.id,
            strategy_id=experiment.wdk_strategy_id,
            error=str(exc),
        )