"""K-fold cross-validation for overfitting detection.
Splits positive and negative control gene lists into k folds,
evaluates each held-out fold, and aggregates metrics to detect
overfitting.
"""
import math
import operator
import random
from collections.abc import Callable, Coroutine
from typing import Any
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.control_tests import run_positive_negative_controls
from veupath_chatbot.services.experiment.metrics import (
compute_confusion_matrix,
compute_metrics,
metrics_from_control_result,
)
from veupath_chatbot.services.experiment.types import (
ConfusionMatrix,
ControlValueFormat,
CrossValidationResult,
ExperimentMetrics,
FoldMetrics,
)
# Module-level logger; used by the k-fold loop to report folds whose
# evaluation raised.
logger = get_logger(__name__)
# Type alias for the optional progress hook: an async callable that takes two
# ints and resolves to None. The k-fold loop awaits it with the 0-based fold
# index and the effective fold count before evaluating each fold.
ProgressCallback = Callable[[int, int], Coroutine[Any, Any, None]]
"""Async callback(fold_index, total_folds) for progress reporting."""
def _stratified_kfold(ids: list[str], k: int, seed: int = 42) -> list[list[str]]:
    """Split a list into k roughly equal folds (deterministic).

    The input is copied, shuffled with a ``random.Random`` seeded by *seed*,
    and dealt round-robin into *k* folds. Fold sizes therefore differ by at
    most one, and the same ``(ids, k, seed)`` always produces the same split.
    This function does no grouping by class itself; it partitions whatever
    single list it is given.

    :param ids: Items to partition. The caller's list is not modified.
    :param k: Number of folds; must be at least 1.
    :param seed: Seed for the shuffle.
    :returns: Exactly ``k`` lists; trailing folds are empty when
        ``k > len(ids)``.
    :raises ValueError: If *k* is less than 1.
    """
    if k < 1:
        # A non-positive fold count has no meaningful split; fail with a clear
        # message instead of an incidental ZeroDivisionError / IndexError.
        raise ValueError(f"k must be at least 1, got {k}")

    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    # The extended slice start::k selects positions start, start+k, ... which
    # is exactly the round-robin deal of the shuffled items.
    return [shuffled[start::k] for start in range(k)]
def _average_metrics(fold_metrics_list: list[ExperimentMetrics]) -> ExperimentMetrics:
    """Average per-fold metrics field by field.

    Rate-style fields are arithmetic means. Count-style fields (the four
    confusion-matrix cells and the three totals) are means passed through the
    built-in ``round`` so they stay integers.

    :param fold_metrics_list: Metrics of each evaluated fold.
    :returns: Averaged metrics. For an empty list, the metrics that
        ``compute_metrics`` derives from an all-zero confusion matrix.
    """
    fold_count = len(fold_metrics_list)
    if not fold_count:
        return compute_metrics(ConfusionMatrix(0, 0, 0, 0))

    def _avg(attr: str) -> float:
        # Mean of a top-level metrics attribute across folds.
        return sum(getattr(m, attr) for m in fold_metrics_list) / fold_count

    def _avg_cell(attr: str) -> int:
        # Rounded mean of one confusion-matrix cell across folds.
        total = sum(getattr(m.confusion_matrix, attr) for m in fold_metrics_list)
        return round(total / fold_count)

    averaged_matrix = ConfusionMatrix(
        true_positives=_avg_cell("true_positives"),
        false_positives=_avg_cell("false_positives"),
        true_negatives=_avg_cell("true_negatives"),
        false_negatives=_avg_cell("false_negatives"),
    )
    return ExperimentMetrics(
        confusion_matrix=averaged_matrix,
        sensitivity=_avg("sensitivity"),
        specificity=_avg("specificity"),
        precision=_avg("precision"),
        negative_predictive_value=_avg("negative_predictive_value"),
        false_positive_rate=_avg("false_positive_rate"),
        false_negative_rate=_avg("false_negative_rate"),
        f1_score=_avg("f1_score"),
        mcc=_avg("mcc"),
        balanced_accuracy=_avg("balanced_accuracy"),
        youdens_j=_avg("youdens_j"),
        total_results=round(_avg("total_results")),
        total_positives=round(_avg("total_positives")),
        total_negatives=round(_avg("total_negatives")),
    )
def _std_metrics(
    fold_metrics_list: list[ExperimentMetrics],
    mean: ExperimentMetrics,
) -> dict[str, float]:
    """Return the sample standard deviation of key metrics across folds.

    Deviations are measured from the values carried by *mean* (it is not
    recomputed here) and the sum of squares is divided by ``n - 1``.

    :param fold_metrics_list: Metrics of each evaluated fold.
    :param mean: Metrics object supplying the centre value for each field.
    :returns: Mapping with keys ``sensitivity``, ``specificity``,
        ``precision``, ``f1Score``, ``mcc`` and ``balancedAccuracy``; empty
        when fewer than two folds are given.
    """
    fold_count = len(fold_metrics_list)
    if fold_count < 2:
        return {}

    # Output key -> attribute name on the metrics objects. Insertion order
    # fixes the key order of the returned dict.
    tracked = {
        "sensitivity": "sensitivity",
        "specificity": "specificity",
        "precision": "precision",
        "f1Score": "f1_score",
        "mcc": "mcc",
        "balancedAccuracy": "balanced_accuracy",
    }

    def _sample_std(attr: str) -> float:
        centre = getattr(mean, attr)
        squared_error = sum(
            (getattr(m, attr) - centre) ** 2 for m in fold_metrics_list
        )
        return math.sqrt(squared_error / (fold_count - 1))

    return {key: _sample_std(attr) for key, attr in tracked.items()}
def _compute_overfitting_score(
    full_metrics: ExperimentMetrics,
    mean_holdout: ExperimentMetrics,
) -> tuple[float, str]:
    """Score how far held-out performance diverges from full-set performance.

    The score is the mean of the absolute differences in F1, sensitivity and
    specificity between the two metric sets.

    :param full_metrics: Metrics computed on the full control sets.
    :param mean_holdout: Metrics averaged over the held-out folds.
    :returns: ``(score, level)`` where level is ``"low"`` for a score below
        0.1, ``"moderate"`` for a score below 0.25, and ``"high"`` otherwise.
    """
    delta_f1 = abs(full_metrics.f1_score - mean_holdout.f1_score)
    delta_sensitivity = abs(full_metrics.sensitivity - mean_holdout.sensitivity)
    delta_specificity = abs(full_metrics.specificity - mean_holdout.specificity)
    gap = (delta_f1 + delta_sensitivity + delta_specificity) / 3.0

    # Ascending "<" checks: anything that fails both comparisons is "high".
    if gap < 0.1:
        return gap, "low"
    if gap < 0.25:
        return gap, "moderate"
    return gap, "high"
# Type alias for the per-fold evaluation hook consumed by the k-fold loop: an
# async callable taking (held-out positive IDs or None, held-out negative IDs
# or None) and resolving to a JSONObject. The loop passes None in place of an
# empty held-out list.
FoldEvaluator = Callable[
[list[str] | None, list[str] | None],
Coroutine[Any, Any, JSONObject],
]
"""Async callback(holdout_pos, holdout_neg) → control-test result dict."""
async def _run_kfold(
    *,
    positive_controls: list[str],
    negative_controls: list[str],
    evaluator: FoldEvaluator,
    k: int = 5,
    full_metrics: ExperimentMetrics | None = None,
    progress_callback: ProgressCallback | None = None,
) -> CrossValidationResult:
    """Drive the k-fold loop shared by every cross-validation mode.

    Both control lists are split into the same number of folds. Each fold's
    held-out IDs are handed to *evaluator*, folds are evaluated one after
    another, and the per-fold metrics are aggregated into a mean, a standard
    deviation and an overfitting estimate.

    A fold whose evaluation (or result parsing) raises is logged as a warning
    and recorded with metrics computed from zero positive and zero negative
    hits; the loop then continues with the next fold.

    :param positive_controls: IDs expected to be found.
    :param negative_controls: IDs expected not to be found.
    :param evaluator: Async callable that evaluates one fold's held-out controls.
    :param k: Requested number of folds; clamped as described below.
    :param full_metrics: Pre-computed full-set metrics (for overfitting
        comparison). When omitted the overfitting score is ``0.0`` / ``"low"``.
    :param progress_callback: Optional progress reporter, awaited with
        ``(fold_index, k)`` before each fold is evaluated.
    :returns: Cross-validation result with per-fold and aggregate metrics.
    """
    # The fold count may not exceed the smaller control set, but must be at
    # least 2. Sets with fewer than two IDs are left out of the cap so the
    # other set can still be cross-validated.
    usable_sizes = [
        len(controls)
        for controls in (positive_controls, negative_controls)
        if len(controls) > 1
    ]
    if usable_sizes:
        k = max(2, min([k, *usable_sizes]))
    else:
        k = 2

    positive_folds = _stratified_kfold(positive_controls, k)
    negative_folds = _stratified_kfold(negative_controls, k)

    per_fold: list[FoldMetrics] = []
    for index, (held_pos, held_neg) in enumerate(zip(positive_folds, negative_folds)):
        if progress_callback:
            await progress_callback(index, k)

        try:
            # An empty held-out list is passed as None.
            outcome = await evaluator(held_pos or None, held_neg or None)
            metrics = metrics_from_control_result(outcome)
        except Exception as exc:
            logger.warning("Fold %d failed: %s", index, exc)
            # Best effort: keep the fold in the result, scored as zero hits.
            metrics = compute_metrics(
                compute_confusion_matrix(
                    positive_hits=0,
                    total_positives=len(held_pos),
                    negative_hits=0,
                    total_negatives=len(held_neg),
                )
            )

        per_fold.append(
            FoldMetrics(
                fold_index=index,
                metrics=metrics,
                positive_control_ids=held_pos,
                negative_control_ids=held_neg,
            )
        )

    all_metrics = [fold.metrics for fold in per_fold]
    mean = _average_metrics(all_metrics)
    std = _std_metrics(all_metrics, mean)

    if full_metrics is None:
        overfit_score, overfit_level = 0.0, "low"
    else:
        overfit_score, overfit_level = _compute_overfitting_score(full_metrics, mean)

    return CrossValidationResult(
        k=k,
        folds=per_fold,
        mean_metrics=mean,
        std_metrics=std,
        overfitting_score=overfit_score,
        overfitting_level=overfit_level,
    )
# Public entry point for k-fold cross-validation.
async def run_cross_validation(
    *,
    site_id: str,
    record_type: str,
    controls_search_name: str,
    controls_param_name: str,
    positive_controls: list[str],
    negative_controls: list[str],
    controls_value_format: ControlValueFormat = "newline",
    # Single-step params (required when tree is None):
    search_name: str | None = None,
    parameters: JSONObject | None = None,
    # Tree-mode param (when provided, uses tree evaluation):
    tree: JSONObject | None = None,
    k: int = 5,
    full_metrics: ExperimentMetrics | None = None,
    progress_callback: ProgressCallback | None = None,
) -> CrossValidationResult:
    """Run k-fold cross-validation on control gene lists.

    The evaluation mode is chosen by *tree*: when it is given, every fold is
    evaluated against the full strategy tree; otherwise every fold is
    evaluated against the single step described by *search_name* and
    *parameters*. The fold loop itself is delegated to ``_run_kfold``.

    :param site_id: Forwarded unchanged to the fold evaluator.
    :param record_type: Forwarded unchanged to the fold evaluator.
    :param controls_search_name: Forwarded unchanged to the fold evaluator.
    :param controls_param_name: Forwarded unchanged to the fold evaluator.
    :param positive_controls: IDs expected to be found.
    :param negative_controls: IDs expected not to be found.
    :param controls_value_format: Forwarded unchanged to the fold evaluator.
    :param search_name: Target search for single-step mode.
    :param parameters: Target search parameters for single-step mode.
    :param tree: Strategy tree; selects tree mode when not ``None``.
    :param k: Requested number of folds.
    :param full_metrics: Pre-computed full-set metrics for the overfitting
        comparison.
    :param progress_callback: Optional per-fold progress reporter.
    :returns: Cross-validation result with per-fold and aggregate metrics.
    :raises ValueError: If *tree* is ``None`` and *search_name* or
        *parameters* is missing.
    """
    evaluator: FoldEvaluator

    if tree is None:
        if search_name is None or parameters is None:
            raise ValueError(
                "search_name and parameters are required for single-step cross-validation"
            )
        # Rebind so the closure below captures values known to be non-None.
        step_search = search_name
        step_parameters = parameters

        async def _single_step_fold(
            pos: list[str] | None, neg: list[str] | None
        ) -> JSONObject:
            return await run_positive_negative_controls(
                site_id=site_id,
                record_type=record_type,
                target_search_name=step_search,
                target_parameters=step_parameters,
                controls_search_name=controls_search_name,
                controls_param_name=controls_param_name,
                positive_controls=pos,
                negative_controls=neg,
                controls_value_format=controls_value_format,
            )

        evaluator = _single_step_fold
    else:
        # Imported only when tree mode is actually used.
        from veupath_chatbot.services.experiment.step_analysis import (
            run_controls_against_tree,
        )

        async def _tree_fold(
            pos: list[str] | None, neg: list[str] | None
        ) -> JSONObject:
            return await run_controls_against_tree(
                site_id=site_id,
                record_type=record_type,
                tree=tree,
                controls_search_name=controls_search_name,
                controls_param_name=controls_param_name,
                controls_value_format=controls_value_format,
                positive_controls=pos,
                negative_controls=neg,
            )

        evaluator = _tree_fold

    return await _run_kfold(
        positive_controls=positive_controls,
        negative_controls=negative_controls,
        evaluator=evaluator,
        k=k,
        full_metrics=full_metrics,
        progress_callback=progress_callback,
    )