Source code for veupath_chatbot.services.experiment.step_analysis.phase_step_eval

"""Phase 1: Per-step evaluation -- evaluate each leaf independently."""

import asyncio

from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.control_tests import run_positive_negative_controls
from veupath_chatbot.services.experiment.helpers import ProgressCallback
from veupath_chatbot.services.experiment.step_analysis._evaluation import (
    _extract_eval_counts,
    run_controls_against_tree,
)
from veupath_chatbot.services.experiment.step_analysis._tree_utils import (
    _collect_leaves,
    _extract_leaf_branch,
    _node_id,
)
from veupath_chatbot.services.experiment.types import (
    ControlValueFormat,
    StepEvaluation,
    to_json,
)

logger = get_logger(__name__)


[docs] async def evaluate_steps( *, site_id: str, record_type: str, tree: JSONObject, controls_search_name: str, controls_param_name: str, controls_value_format: ControlValueFormat, positive_controls: list[str], negative_controls: list[str], progress_callback: ProgressCallback | None = None, ) -> list[StepEvaluation]: """Evaluate each leaf step against controls, preserving ancestor transforms. For each leaf, the evaluation includes any transform chain above it (e.g. ``GenesByOrthologs``) so that cross-organism searches are converted before being compared against controls. :param tree: ``PlanStepNode``-shaped dict. :returns: One :class:`StepEvaluation` per leaf. """ leaves = _collect_leaves(tree) if not leaves: return [] results: list[StepEvaluation] = [] sem = asyncio.Semaphore(3) async def _eval_leaf(leaf: JSONObject, idx: int) -> StepEvaluation | None: search_name = str(leaf.get("searchName", "")) if not search_name or search_name.startswith("__"): return None display = str(leaf.get("displayName", search_name)) lid = _node_id(leaf) branch = _extract_leaf_branch(tree, lid) has_transforms = branch is not None and isinstance( branch.get("primaryInput"), dict ) if progress_callback: label = f"{display} (via transforms)" if has_transforms else display await progress_callback( { "type": "step_analysis_progress", "data": { "phase": "step_evaluation", "message": f"Evaluating step {idx + 1}/{len(leaves)}: {label}", "current": idx + 1, "total": len(leaves), }, } ) try: async with sem: if branch is not None and has_transforms: raw = await run_controls_against_tree( site_id=site_id, record_type=record_type, tree=branch, controls_search_name=controls_search_name, controls_param_name=controls_param_name, controls_value_format=controls_value_format, positive_controls=positive_controls, negative_controls=negative_controls, ) else: raw_params = leaf.get("parameters") parameters: JSONObject = ( raw_params if isinstance(raw_params, dict) else {} ) raw = await run_positive_negative_controls( site_id=site_id, record_type=record_type, target_search_name=search_name, target_parameters=parameters, controls_search_name=controls_search_name, controls_param_name=controls_param_name, controls_value_format=controls_value_format, positive_controls=positive_controls, negative_controls=negative_controls, ) except Exception as exc: logger.warning("Step evaluation failed", step=lid, error=str(exc)) return None counts = _extract_eval_counts(raw) recall = counts.pos_hits / counts.pos_total if counts.pos_total > 0 else 0.0 fpr = counts.neg_hits / counts.neg_total if counts.neg_total > 0 else 0.0 ev = StepEvaluation( step_id=lid, search_name=search_name, display_name=display, result_count=counts.total_results, positive_hits=counts.pos_hits, positive_total=counts.pos_total, negative_hits=counts.neg_hits, negative_total=counts.neg_total, recall=recall, false_positive_rate=fpr, captured_positive_ids=counts.pos_ids, captured_negative_ids=counts.neg_ids, ) if progress_callback: await progress_callback( { "type": "step_analysis_progress", "data": { "phase": "step_evaluation", "message": f"Evaluated {display}: recall {recall:.0%}, FPR {fpr:.0%}", "current": idx + 1, "total": len(leaves), "stepEvaluation": to_json(ev), }, } ) return ev tasks = [_eval_leaf(leaf, i) for i, leaf in enumerate(leaves)] evaluations = await asyncio.gather(*tasks) results = [e for e in evaluations if e is not None] logger.info("Step evaluation complete", count=len(results)) return results