Source code for veupath_chatbot.services.experiment.robustness

"""Bootstrap robustness and uncertainty estimation.

Resamples control sets with replacement and recomputes rank metrics
to derive confidence intervals and stability scores — all pure Python,
no additional WDK API calls required.
"""

import random
from collections import defaultdict

from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.services.experiment.metrics import (
    compute_confusion_matrix,
    compute_metrics,
)
from veupath_chatbot.services.experiment.rank_metrics import compute_rank_metrics
from veupath_chatbot.services.experiment.types import (
    DEFAULT_K_VALUES,
    BootstrapResult,
    ConfidenceInterval,
    NegativeSetVariant,
)

logger = get_logger(__name__)


[docs] def compute_robustness( result_ids: list[str], positive_ids: list[str], negative_ids: list[str], *, n_bootstrap: int = 200, k_values: list[int] | None = None, seed: int = 42, alternative_negatives: dict[str, list[str]] | None = None, include_rank_metrics: bool = True, ) -> BootstrapResult: """Compute bootstrap confidence intervals for classification (and optionally rank) metrics. :param result_ids: Ordered gene IDs from the strategy result. :param positive_ids: Positive control gene IDs. :param negative_ids: Negative control gene IDs. :param n_bootstrap: Number of bootstrap iterations. :param k_values: K values for Precision/Recall/Enrichment@K. :param seed: Random seed for reproducibility. :param alternative_negatives: Optional map of label -> negative IDs for negative-set sensitivity analysis. :param include_rank_metrics: When ``False``, skip rank metric CIs and top-K stability — only classification CIs are computed. :returns: Bootstrap robustness result. """ if k_values is None: k_values = DEFAULT_K_VALUES rng = random.Random(seed) metric_samples: dict[str, list[float]] = defaultdict(list) rank_metric_samples: dict[str, list[float]] = defaultdict(list) top_k_sets: list[set[str]] = [] pos_list = list(positive_ids) neg_list = list(negative_ids) for _ in range(n_bootstrap): boot_pos = _resample(pos_list, rng) boot_neg = _resample(neg_list, rng) if include_rank_metrics: boot_pos_set = set(boot_pos) boot_neg_set = set(boot_neg) rm = compute_rank_metrics( result_ids=result_ids, positive_ids=boot_pos_set, negative_ids=boot_neg_set, k_values=k_values, ) for kv in k_values: rank_metric_samples[f"precision_at_{kv}"].append( rm.precision_at_k.get(kv, 0.0) ) rank_metric_samples[f"recall_at_{kv}"].append( rm.recall_at_k.get(kv, 0.0) ) rank_metric_samples[f"enrichment_at_{kv}"].append( rm.enrichment_at_k.get(kv, 0.0) ) stability_k = 50 # Use the bootstrapped positive set to determine which of the # top-K results are "relevant" — this varies across iterations, # producing a meaningful stability estimate. top_k_ids = result_ids[:stability_k] boot_relevant = {gid for gid in top_k_ids if gid in boot_pos_set} top_k_sets.append(boot_relevant) _collect_classification_metrics( result_ids, set(boot_pos), set(boot_neg), metric_samples ) metric_cis = {k: _ci_from_samples(v) for k, v in metric_samples.items()} rank_metric_cis = {k: _ci_from_samples(v) for k, v in rank_metric_samples.items()} top_k_stability = _mean_jaccard(top_k_sets) if top_k_sets else 0.0 neg_variants: list[NegativeSetVariant] = [] if include_rank_metrics and alternative_negatives: for label, alt_neg in alternative_negatives.items(): rm = compute_rank_metrics( result_ids=result_ids, positive_ids=set(positive_ids), negative_ids=set(alt_neg), k_values=k_values, ) neg_variants.append( NegativeSetVariant( label=label, negative_count=len(alt_neg), rank_metrics=rm, ) ) return BootstrapResult( n_iterations=n_bootstrap, metric_cis=metric_cis, rank_metric_cis=rank_metric_cis, top_k_stability=top_k_stability, negative_set_sensitivity=neg_variants, )
def _resample(items: list[str], rng: random.Random) -> list[str]: """Resample with replacement.""" n = len(items) if n == 0: return [] return [items[rng.randint(0, n - 1)] for _ in range(n)] def _collect_classification_metrics( result_ids: list[str], pos_set: set[str], neg_set: set[str], samples: dict[str, list[float]], ) -> None: """Compute binary classification metrics and accumulate into samples dict.""" result_set = set(result_ids) cm = compute_confusion_matrix( positive_hits=len(pos_set & result_set), total_positives=len(pos_set), negative_hits=len(neg_set & result_set), total_negatives=len(neg_set), ) m = compute_metrics(cm) samples["sensitivity"].append(m.sensitivity) samples["specificity"].append(m.specificity) samples["precision"].append(m.precision) samples["f1_score"].append(m.f1_score) def _ci_from_samples( samples: list[float], alpha: float = 0.05, ) -> ConfidenceInterval: """Compute percentile-based confidence interval.""" if not samples: return ConfidenceInterval(lower=0.0, mean=0.0, upper=0.0, std=0.0) n = len(samples) sorted_s = sorted(samples) lo_idx = max(0, int(n * alpha / 2)) hi_idx = min(n - 1, int(n * (1 - alpha / 2))) mean = sum(sorted_s) / n variance = sum((x - mean) ** 2 for x in sorted_s) / max(n - 1, 1) std = variance**0.5 return ConfidenceInterval( lower=sorted_s[lo_idx], mean=mean, upper=sorted_s[hi_idx], std=std, ) def _mean_jaccard(sets: list[set[str]]) -> float: """Average pairwise Jaccard similarity (sampled for efficiency).""" if len(sets) < 2: return 1.0 n = len(sets) max_pairs = 200 rng = random.Random(0) total = 0.0 count = 0 for _ in range(max_pairs): indices = rng.sample(range(n), 2) i, j = indices[0], indices[1] inter = len(sets[i] & sets[j]) union = len(sets[i] | sets[j]) if union > 0: total += inter / union count += 1 return total / count if count > 0 else 1.0