Source code for veupath_chatbot.services.experiment.step_analysis.phase_operators

"""Phase 2: Operator comparison -- try all operators at each combine node."""

import asyncio

from veupath_chatbot.domain.strategy.ops import CombineOp
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.helpers import ProgressCallback
from veupath_chatbot.services.experiment.step_analysis._evaluation import (
    _extract_eval_counts,
    _f1_from_counts,
    run_controls_against_tree,
)
from veupath_chatbot.services.experiment.step_analysis._tree_utils import (
    _build_subtree_with_operator,
    _collect_combine_nodes,
    _node_id,
)
from veupath_chatbot.services.experiment.types import (
    ControlValueFormat,
    OperatorComparison,
    OperatorVariant,
    to_json,
)

COMPARISON_OPERATORS = [
    op.value
    for op in CombineOp
    if op not in (CombineOp.COLOCATE, CombineOp.LONLY, CombineOp.RONLY)
]

logger = get_logger(__name__)


[docs] async def compare_operators( *, site_id: str, record_type: str, tree: JSONObject, controls_search_name: str, controls_param_name: str, controls_value_format: ControlValueFormat, positive_controls: list[str], negative_controls: list[str], progress_callback: ProgressCallback | None = None, ) -> list[OperatorComparison]: """For each combine node, evaluate INTERSECT, UNION, MINUS and recommend. :param tree: ``PlanStepNode``-shaped dict. :returns: One :class:`OperatorComparison` per combine node. """ combine_nodes = _collect_combine_nodes(tree) if not combine_nodes: return [] results: list[OperatorComparison] = [] sem = asyncio.Semaphore(2) for ci, cnode in enumerate(combine_nodes): cid = _node_id(cnode) current_op = str(cnode.get("operator", "INTERSECT")) if progress_callback: await progress_callback( { "type": "step_analysis_progress", "data": { "phase": "operator_comparison", "message": f"Comparing operators at node {ci + 1}/{len(combine_nodes)}", "current": ci + 1, "total": len(combine_nodes), }, } ) variants: list[OperatorVariant] = [] async def _try_operator( op: str, _cnode: JSONObject = cnode, _cid: str = cid, ) -> OperatorVariant | None: subtree = _build_subtree_with_operator(_cnode, op) try: async with sem: raw = await run_controls_against_tree( site_id=site_id, record_type=record_type, tree=subtree, controls_search_name=controls_search_name, controls_param_name=controls_param_name, controls_value_format=controls_value_format, positive_controls=positive_controls, negative_controls=negative_controls, ) except Exception as exc: logger.warning( "Operator comparison failed", node=_cid, op=op, error=str(exc) ) return None counts = _extract_eval_counts(raw) recall = counts.pos_hits / counts.pos_total if counts.pos_total > 0 else 0.0 fpr = counts.neg_hits / counts.neg_total if counts.neg_total > 0 else 0.0 return OperatorVariant( operator=op, positive_hits=counts.pos_hits, negative_hits=counts.neg_hits, total_results=counts.total_results, recall=recall, false_positive_rate=fpr, f1_score=_f1_from_counts(counts), ) op_tasks = [_try_operator(op) for op in COMPARISON_OPERATORS] op_results = await asyncio.gather(*op_tasks) variants = [v for v in op_results if v is not None] best = max(variants, key=lambda v: v.f1_score) if variants else None recommendation = "" recommended_op = current_op if best and best.operator != current_op: cur = next((v for v in variants if v.operator == current_op), None) if cur and best.f1_score > cur.f1_score + 0.01: recommended_op = best.operator recommendation = ( f"Switching from {current_op} to {best.operator} " f"improves F1 from {cur.f1_score:.2f} to {best.f1_score:.2f} " f"(recall {cur.recall:.0%} -> {best.recall:.0%}, " f"FPR {cur.false_positive_rate:.0%} -> {best.false_positive_rate:.0%})" ) else: recommendation = ( f"Current operator {current_op} is already optimal or near-optimal." ) elif best: recommendation = f"Current operator {current_op} is already optimal." oc = OperatorComparison( combine_node_id=cid, current_operator=current_op, variants=variants, recommendation=recommendation, recommended_operator=recommended_op, ) results.append(oc) if progress_callback: await progress_callback( { "type": "step_analysis_progress", "data": { "phase": "operator_comparison", "message": ( f"Node {cid}: {recommendation}" if recommendation else f"Node {cid}: compared {len(variants)} operators" ), "current": ci + 1, "total": len(combine_nodes), "operatorComparison": to_json(oc), }, } ) logger.info("Operator comparison complete", count=len(results)) return results