Source code for veupath_chatbot.services.experiment.enrichment_compare
"""Cross-experiment enrichment comparison."""
from typing import TypedDict
from veupath_chatbot.platform.types import JSONValue
from veupath_chatbot.services.experiment.types import Experiment
[docs]
class EnrichmentRow(TypedDict):
"""Shape of one term row in the enrichment comparison."""
termKey: str
termName: str
analysisType: str
scores: dict[str, JSONValue]
maxScore: float
experimentCount: int
[docs]
class EnrichmentCompareResult(TypedDict):
"""Return shape of :func:`compare_enrichment_across`."""
experimentIds: list[str]
experimentLabels: dict[str, str]
rows: list[EnrichmentRow]
totalTerms: int
[docs]
def compare_enrichment_across(
experiments: list[Experiment],
experiment_ids: list[str],
analysis_type: str | None = None,
) -> EnrichmentCompareResult:
"""Compare enrichment results across experiments.
Builds a term-by-experiment matrix of fold-enrichment scores.
Optionally filters to a single analysis type.
"""
labels: dict[str, str] = {
exp.id: (exp.config.name or exp.id) for exp in experiments
}
# Collect scores: term_key -> { experiment_id -> fold_enrichment }
# Also track term metadata (name, analysis type)
term_scores: dict[str, dict[str, float]] = {}
term_meta: dict[str, tuple[str, str]] = {}
for exp in experiments:
for er in exp.enrichment_results:
if analysis_type and er.analysis_type != analysis_type:
continue
for term in er.terms:
key = f"{er.analysis_type}:{term.term_id}"
if key not in term_meta:
term_meta[key] = (term.term_name, er.analysis_type)
term_scores[key] = {}
term_scores[key][exp.id] = term.fold_enrichment
# Build rows sorted by max score descending
rows: list[EnrichmentRow] = []
for key in sorted(
term_scores,
key=lambda k: max(term_scores[k].values()) if term_scores[k] else 0.0,
reverse=True,
):
name, a_type = term_meta[key]
scores_map = term_scores[key]
scores_for_row: dict[str, JSONValue] = {
eid: round(scores_map[eid], 4) if eid in scores_map else None
for eid in experiment_ids
}
max_score = max(scores_map.values()) if scores_map else 0.0
rows.append(
{
"termKey": key,
"termName": name,
"analysisType": a_type,
"scores": scores_for_row,
"maxScore": round(max_score, 4),
"experimentCount": len(scores_map),
}
)
return {
"experimentIds": list(experiment_ids),
"experimentLabels": labels,
"rows": rows,
"totalTerms": len(rows),
}