Source code for veupath_chatbot.services.gene_sets.reverse_search
"""Reverse search — rank gene sets by how well they recover positive genes.
Given a set of known-positive gene IDs, score each candidate gene set on
recall, precision, and F1 using pure set intersection. No WDK calls needed
because the gene IDs are already materialised.
"""
from dataclasses import dataclass
[docs]
@dataclass(frozen=True, slots=True)
class GeneSetCandidate:
"""A gene set to evaluate against the positive controls."""
id: str
name: str
gene_ids: list[str]
search_name: str | None = None
[docs]
@dataclass(frozen=True, slots=True)
class RankedResult:
"""A scored gene set with classification metrics."""
gene_set_id: str
name: str
search_name: str | None
recall: float
precision: float
f1: float
result_count: int
overlap_count: int
[docs]
def rank_gene_sets_by_recall(
    gene_sets: list[GeneSetCandidate],
    positive_ids: list[str],
    negative_ids: list[str] | None = None,
) -> list[RankedResult]:
    """Score every candidate gene set and order them best-first.

    Results are ordered by recall (descending); ties are broken by F1
    (descending). Candidates that tie on both keep their input order.

    :param gene_sets: Candidate gene sets to evaluate.
    :param positive_ids: Known-positive gene IDs that should be recovered.
    :param negative_ids: Optional negative controls. When given (and
        non-empty), only hits on these IDs count as false positives.
    :returns: One ``RankedResult`` per candidate, best first; an empty list
        when *gene_sets* is empty.
    """
    if not gene_sets:
        return []

    positives = set(positive_ids)
    negatives = set(negative_ids) if negative_ids else set()
    positive_total = len(positives)

    def score(candidate: GeneSetCandidate) -> RankedResult:
        """Compute recall, precision and F1 for one candidate."""
        members = set(candidate.gene_ids)
        member_total = len(members)
        true_pos = len(members & positives)

        recall = true_pos / positive_total if positives else 0.0

        # With negative controls, members that are neither positive nor
        # negative are unknown and left out of the precision denominator.
        # Without them, every non-positive member is a false positive.
        if negatives:
            false_pos = len(members & negatives)
        else:
            false_pos = member_total - true_pos

        judged = true_pos + false_pos
        precision = true_pos / judged if judged > 0 else 0.0

        pr_sum = precision + recall
        f1 = (2 * precision * recall / pr_sum) if pr_sum > 0 else 0.0

        return RankedResult(
            gene_set_id=candidate.id,
            name=candidate.name,
            search_name=candidate.search_name,
            recall=recall,
            precision=precision,
            f1=f1,
            result_count=member_total,
            overlap_count=true_pos,
        )

    # sorted() is stable, so equal (recall, f1) pairs keep their input order.
    return sorted(
        (score(candidate) for candidate in gene_sets),
        key=lambda ranked: (-ranked.recall, -ranked.f1),
    )