"""Main entry point: run_step_analysis coordinates all four analysis phases."""
from collections.abc import Callable
from typing import Any, TypedDict
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.helpers import ProgressCallback
from veupath_chatbot.services.experiment.step_analysis._evaluation import (
_extract_eval_counts,
)
from veupath_chatbot.services.experiment.step_analysis.phase_contribution import (
analyze_contributions,
)
from veupath_chatbot.services.experiment.step_analysis.phase_operators import (
compare_operators,
)
from veupath_chatbot.services.experiment.step_analysis.phase_sensitivity import (
sweep_parameters,
)
from veupath_chatbot.services.experiment.step_analysis.phase_step_eval import (
evaluate_steps,
)
from veupath_chatbot.services.experiment.types import (
ControlValueFormat,
StepAnalysisResult,
StepContribution,
StepEvaluation,
)
logger = get_logger(__name__)
class _SharedPhaseKwargs(TypedDict):
    """Keyword arguments that ``run_step_analysis`` forwards to every phase.

    The mapping is built once per run and unpacked with ``**`` into each
    phase function call, so every phase function must accept all of these
    keys as keyword arguments.
    """

    site_id: str
    record_type: str
    # Described by ``run_step_analysis`` as a ``PlanStepNode``-shaped dict.
    tree: JSONObject
    # NOTE(review): presumably name the search and parameter used to evaluate
    # the control sets -- confirm against the phase modules.
    controls_search_name: str
    controls_param_name: str
    controls_value_format: ControlValueFormat
    # Control identifier lists (positives / negatives).
    positive_controls: list[str]
    negative_controls: list[str]
    # Optional progress reporter; ``run_step_analysis`` awaits it with a
    # ``step_analysis_progress`` event dict before each phase.
    progress_callback: ProgressCallback | None
# ---------------------------------------------------------------------------
# Post-processing: enrich evaluations / contributions with derived fields
# ---------------------------------------------------------------------------
def _enrich_step_evals_with_movement(
    evals: list[StepEvaluation],
    baseline_result: JSONObject,
) -> list[StepEvaluation]:
    """Return copies of *evals* annotated with movement versus the baseline.

    Movement is the signed difference between a step's own counts and the
    counts extracted from the full-strategy ``baseline_result``:

    * ``tp_movement`` -- step positive hits minus baseline positive hits
    * ``fp_movement`` -- step negative hits minus baseline negative hits
    * ``fn_movement`` -- step missed positives minus baseline missed positives

    :param evals: Per-step evaluations to annotate (not mutated).
    :param baseline_result: Raw result of evaluating the full strategy.
    :returns: New :class:`StepEvaluation` objects, in the same order.
    """
    reference = _extract_eval_counts(baseline_result)
    ref_true_pos = reference.pos_hits
    ref_false_pos = reference.neg_hits
    # False negatives are the positives the baseline failed to capture.
    ref_false_neg = reference.pos_total - reference.pos_hits

    def _with_movement(step: StepEvaluation) -> StepEvaluation:
        """Rebuild one evaluation with the three movement fields filled in."""
        missed_positives = step.positive_total - step.positive_hits
        return StepEvaluation(
            step_id=step.step_id,
            search_name=step.search_name,
            display_name=step.display_name,
            result_count=step.result_count,
            positive_hits=step.positive_hits,
            positive_total=step.positive_total,
            negative_hits=step.negative_hits,
            negative_total=step.negative_total,
            recall=step.recall,
            false_positive_rate=step.false_positive_rate,
            captured_positive_ids=step.captured_positive_ids,
            captured_negative_ids=step.captured_negative_ids,
            tp_movement=step.positive_hits - ref_true_pos,
            fp_movement=step.negative_hits - ref_false_pos,
            fn_movement=missed_positives - ref_false_neg,
        )

    return [_with_movement(step) for step in evals]
def _enrich_contributions_with_narrative(
    contributions: list[StepContribution],
) -> list[StepContribution]:
    """Generate human-readable narrative text for each contribution.

    The narrative describes what happens when the step is removed:

    * a recall clause when ``recall_delta`` is below -0.05 or above 0.02;
    * a false-positive-rate clause when ``fpr_delta`` is below -0.05 or
      above 0.05;
    * a verdict-based fallback sentence when neither threshold is crossed.

    When only the false-positive-rate clause applies it is written as a
    complete sentence ("Removing this step reduces ...") instead of a
    fragment starting with "and"/"but".

    :param contributions: Ablation results to annotate (not mutated).
    :returns: New :class:`StepContribution` objects, in the same order, with
        ``narrative`` populated.
    """
    enriched: list[StepContribution] = []
    for sc in contributions:
        recall_clause: str | None = None
        if sc.recall_delta < -0.05:
            recall_clause = (
                f"Removing this step drops recall by {abs(sc.recall_delta):.0%}"
            )
        elif sc.recall_delta > 0.02:
            recall_clause = (
                f"Removing this step improves recall by {sc.recall_delta:.0%}"
            )

        # The conjunction is only used when a recall clause precedes the
        # false-positive-rate effect.
        fpr_conjunction = ""
        fpr_effect: str | None = None
        if sc.fpr_delta < -0.05:
            fpr_conjunction = "and"
            fpr_effect = f"reduces false positive rate by {abs(sc.fpr_delta):.0%}"
        elif sc.fpr_delta > 0.05:
            fpr_conjunction = "but"
            fpr_effect = f"increases false positive rate by {sc.fpr_delta:.0%}"

        if recall_clause is not None and fpr_effect is not None:
            narrative = f"{recall_clause}, {fpr_conjunction} {fpr_effect}."
        elif recall_clause is not None:
            narrative = f"{recall_clause}."
        elif fpr_effect is not None:
            # No recall clause to attach to: give the sentence its own subject.
            narrative = f"Removing this step {fpr_effect}."
        elif sc.verdict == "neutral":
            narrative = "This step has minimal impact on results."
        elif sc.verdict == "essential":
            narrative = "This step is critical \u2014 removing it significantly hurts recall."
        else:
            narrative = f"Verdict: {sc.verdict}."

        enriched.append(
            StepContribution(
                step_id=sc.step_id,
                search_name=sc.search_name,
                baseline_recall=sc.baseline_recall,
                ablated_recall=sc.ablated_recall,
                recall_delta=sc.recall_delta,
                baseline_fpr=sc.baseline_fpr,
                ablated_fpr=sc.ablated_fpr,
                fpr_delta=sc.fpr_delta,
                verdict=sc.verdict,
                enrichment_delta=sc.enrichment_delta,
                narrative=narrative,
            )
        )
    return enriched
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
async def run_step_analysis(
    *,
    site_id: str,
    record_type: str,
    tree: JSONObject,
    controls_search_name: str,
    controls_param_name: str,
    controls_value_format: ControlValueFormat,
    positive_controls: list[str],
    negative_controls: list[str],
    baseline_result: JSONObject,
    phases: list[str] | None = None,
    progress_callback: ProgressCallback | None = None,
) -> StepAnalysisResult:
    """Execute the selected analysis phases and bundle their outputs.

    Phases run sequentially in a fixed order -- step evaluation, operator
    comparison, contribution (ablation), sensitivity -- regardless of the
    order of *phases*. Before each phase, *progress_callback* (if given) is
    awaited with a ``step_analysis_progress`` event.

    :param tree: ``PlanStepNode``-shaped dict.
    :param baseline_result: Raw result from the initial tree evaluation.
        Passed to the contribution phase as ``baseline_metrics`` and used to
        compute movement fields on step evaluations.
    :param phases: Phase keys to run. ``None`` or an empty list runs all
        four; unrecognised keys are ignored.
    :param progress_callback: Optional awaitable progress reporter; also
        forwarded to every phase function.
    :returns: Aggregated :class:`StepAnalysisResult`; phases that did not
        run contribute an empty list.
    """
    if phases:
        selected = set(phases)
    else:
        selected = {
            "step_evaluation",
            "operator_comparison",
            "contribution",
            "sensitivity",
        }

    # Arguments every phase function receives unchanged.
    common = _SharedPhaseKwargs(
        site_id=site_id,
        record_type=record_type,
        tree=tree,
        controls_search_name=controls_search_name,
        controls_param_name=controls_param_name,
        controls_value_format=controls_value_format,
        positive_controls=positive_controls,
        negative_controls=negative_controls,
        progress_callback=progress_callback,
    )

    # Execution plan: (phase key, progress message, coroutine fn, extra kwargs).
    plan: tuple[
        tuple[str, str, Callable[..., Any], dict[str, JSONObject]], ...
    ] = (
        ("step_evaluation", "Starting per-step evaluation...", evaluate_steps, {}),
        ("operator_comparison", "Comparing operators...", compare_operators, {}),
        (
            "contribution",
            "Running ablation analysis...",
            analyze_contributions,
            {"baseline_metrics": baseline_result},
        ),
        ("sensitivity", "Sweeping parameters...", sweep_parameters, {}),
    )

    outputs: dict[str, Any] = {}
    for key, announcement, runner, extras in plan:
        if key in selected:
            if progress_callback:
                event = {
                    "type": "step_analysis_progress",
                    "data": {"phase": key, "message": announcement},
                }
                await progress_callback(event)
            outputs[key] = await runner(**common, **extras)

    # Derived fields are only added when the phase produced something.
    evaluations: list[StepEvaluation] = outputs.get("step_evaluation", [])
    if evaluations:
        evaluations = _enrich_step_evals_with_movement(evaluations, baseline_result)

    ablations: list[StepContribution] = outputs.get("contribution", [])
    if ablations:
        ablations = _enrich_contributions_with_narrative(ablations)

    operator_results = outputs.get("operator_comparison", [])
    sensitivity_results = outputs.get("sensitivity", [])
    return StepAnalysisResult(
        step_evaluations=evaluations,
        operator_comparisons=operator_results,
        step_contributions=ablations,
        parameter_sensitivities=sensitivity_results,
    )