Source code for veupath_chatbot.services.search_rerank

"""Reusable search result reranking utilities.

Implements a "fetch wide, rerank narrow" pattern for VEuPathDB search:

1. **Analyse** the query to detect intent (gene ID prefix, organism
   abbreviation, free text, etc.).
2. **Fetch** broadly from one or more sources (site-search, WDK).
3. **Score** each result on multiple relevance signals.
4. **Deduplicate** by primary key, keeping the highest-scored entry.
5. **Return** the top-N results sorted by combined score.
"""

import re
from collections.abc import Callable, Sequence
from dataclasses import dataclass

from rapidfuzz import fuzz

from veupath_chatbot.platform.types import JSONObject


def score_text_match(query: str, value: str) -> float:
    """Rate how closely *query* matches *value* on a 0.0--1.0 scale.

    Both strings are trimmed and lower-cased before being compared.  Cheap
    deterministic checks run first, because exact and prefix hits are the
    ones that matter most for gene ID lookups; ``rapidfuzz`` is consulted
    only when none of them applies.

    :param query: Text the user searched for.
    :param value: Candidate text to compare against.
    :returns: ``0.0`` if either string is blank, ``1.0`` for an exact
        match, ``0.95`` when *value* starts with *query*, ``0.80`` when
        *query* occurs elsewhere inside *value*, otherwise the
        ``fuzz.WRatio`` score rescaled from 0--100 down to 0--1.
    """
    needle = query.strip().lower()
    haystack = value.strip().lower()

    if not (needle and haystack):
        return 0.0
    if needle == haystack:
        return 1.0
    if needle in haystack:
        # A prefix hit is just a substring hit anchored at position 0, and
        # it ranks above a substring found further along.
        return 0.95 if haystack.startswith(needle) else 0.80

    # Fuzzy fallback: WRatio combines partial, token-sort and token-set
    # ratios internally and reports the best one as a 0-100 score.
    return fuzz.WRatio(needle, haystack) / 100.0
# Field names for which ``score_field_quality`` awards its top score (1.0)
# as soon as any one of them appears among the matched fields.
PRIMARY_MATCH_FIELDS: frozenset[str] = frozenset(
    (
        "gene_source_id",
        "gene_name",
        "gene_product",
        "gene_type",
        "gene_organism_full",
        "primary_key",
        "hyperlinkName",
    )
)

# Field names for which ``score_field_quality`` returns a penalty (-0.5)
# when the query matched one of them and no primary field.
SECONDARY_MATCH_FIELDS: frozenset[str] = frozenset(
    (
        "gene_Notes",
        "gene_PubMed",
        "gene_UserCommentContent",
        "autocomplete",
        "MULTIgene_Notes",
        "MULTIgene_PubMed",
    )
)
def score_field_quality(matched_fields: Sequence[str]) -> float:
    """Rate a hit by the kind of field the query matched in.

    :param matched_fields: Names of the fields in which the query matched.
    :returns: ``1.0`` if at least one name is in ``PRIMARY_MATCH_FIELDS``;
        otherwise ``-0.5`` if at least one is in
        ``SECONDARY_MATCH_FIELDS``; otherwise ``0.0`` (which is also the
        result for an empty *matched_fields*).
    """
    if not matched_fields:
        return 0.0

    # Tiers are checked in priority order, so one primary field outweighs
    # any number of secondary ones.
    tiers = (
        (PRIMARY_MATCH_FIELDS, 1.0),
        (SECONDARY_MATCH_FIELDS, -0.5),
    )
    for tier_fields, tier_score in tiers:
        if any(name in tier_fields for name in matched_fields):
            return tier_score
    return 0.0
@dataclass
class ScoredResult:
    """Pairs a raw search hit with the relevance score assigned to it."""

    # The search hit itself, as a JSON object.
    result: JSONObject
    # Relevance score; ``dedup_and_sort`` treats a larger value as better.
    score: float
    # Free-form label, empty by default.  Presumably names the backend that
    # produced the hit (e.g. site-search or WDK) -- confirm with callers.
    source: str = ""
def dedup_and_sort(
    results: Sequence[ScoredResult],
    key_fn: Callable[[JSONObject], str],
) -> list[ScoredResult]:
    """Deduplicate results by key, keeping the highest-scoring entry.

    :param results: Scored hits, possibly containing several entries for
        the same record.
    :param key_fn: Extracts the deduplication key from a hit's ``result``.
        It is called exactly once per element of *results*.  Hits whose
        key is falsy (for example an empty string) are dropped.
    :returns: One entry per distinct key, ordered by descending score and
        then by ascending key so that ties have a deterministic order.
        When several entries with the same key share the top score, the
        one that appears first in *results* is kept.
    """
    best_by_key: dict[str, ScoredResult] = {}
    for candidate in results:
        key = key_fn(candidate.result)
        if not key:
            continue
        incumbent = best_by_key.get(key)
        # Strict ">" means an equal-scoring later duplicate never replaces
        # the entry that was seen first.
        if incumbent is None or candidate.score > incumbent.score:
            best_by_key[key] = candidate

    # The dict already holds every survivor's key, so sort on that instead
    # of calling key_fn a second time for each entry.
    ranked = sorted(
        best_by_key.items(),
        key=lambda item: (-item[1].score, item[0]),
    )
    return [entry for _, entry in ranked]
# Recognises a query that opens like a gene ID: two to eight letters, an
# optional "_" or "-" separator, then a digit.  Only the start of the string
# is anchored, so anything may follow that prefix.
_GENE_ID_PREFIX_RE = re.compile(r"^[A-Za-z]{2,8}[_\-]?\d", flags=re.IGNORECASE)
[docs] @dataclass(frozen=True) class QueryIntent: """What we think the user is looking for.""" raw: str is_gene_id_like: bool = False implied_organism: str | None = None implied_organism_score: float = 0.0 wildcard_ids: tuple[str, ...] = ()
def _build_wildcard_ids(query: str) -> tuple[str, ...]:
    """Build the wildcard ID patterns to try for a gene-ID-like query.

    :param query: Query text; surrounding whitespace is ignored.
    :returns: ``()`` for a blank query.  If the query contains ``"_"`` the
        only pattern is the query followed by ``*``, with its case left
        untouched.  Otherwise the upper-cased query is offered both with
        and without a trailing underscore before the ``*``, followed by
        the query in its original case when that differs.
    """
    stem = query.strip()
    if not stem:
        return ()

    # An underscore is already present, so only a trailing wildcard is added.
    if "_" in stem:
        return (f"{stem}*",)

    shouted = stem.upper()
    candidates = [f"{shouted}_*", f"{shouted}*"]
    if shouted != stem:
        candidates.append(f"{stem}*")

    # dict.fromkeys removes any repeats while preserving insertion order.
    return tuple(dict.fromkeys(candidates))
def analyse_query(
    query: str,
    available_organisms: list[str],
    organism_scorer: Callable[[str, str], float] | None = None,
) -> QueryIntent:
    """Work out what a search query is most likely asking for.

    :param query: User's raw search text; surrounding whitespace is ignored.
    :param available_organisms: Canonical organism names from the site.
    :param organism_scorer: A ``(query, organism) -> float`` scorer.  When
        omitted, ``_default_organism_scorer`` is used.
    :returns: A :class:`QueryIntent`.  A blank query yields one with every
        field at its default.  Otherwise it records whether the query
        starts like a gene ID, the best-scoring organism (only if that
        score is at least 0.60), and wildcard ID patterns when the query is
        gene-ID-like.
    """
    text = query.strip()
    if not text:
        return QueryIntent(raw=text)

    rate = organism_scorer if organism_scorer else _default_organism_scorer

    # Linear scan for the top organism.  The strict ">" keeps the earliest
    # organism on ties and ignores scores that are not above zero.
    top_org: str | None = None
    top_score: float = 0.0
    for candidate in available_organisms:
        candidate_score = rate(text, candidate)
        if candidate_score > top_score:
            top_org, top_score = candidate, candidate_score

    # A weak best match is discarded entirely rather than reported with a
    # low score.
    if top_score < 0.60:
        top_org, top_score = None, 0.0

    looks_like_id = _GENE_ID_PREFIX_RE.match(text) is not None
    return QueryIntent(
        raw=text,
        is_gene_id_like=looks_like_id,
        implied_organism=top_org,
        implied_organism_score=top_score,
        wildcard_ids=_build_wildcard_ids(text) if looks_like_id else (),
    )
def _default_organism_scorer(query: str, organism: str) -> float:
    """Fallback organism scorer -- a simple case-insensitive substring check.

    :param query: Search text; trimmed and lower-cased before comparing.
    :param organism: Organism name; trimmed and lower-cased before comparing.
    :returns: ``1.0`` when the two are equal, ``0.7`` when *query* occurs
        inside *organism*, and ``0.0`` otherwise, including whenever either
        string is blank.
    """
    q = query.strip().lower()
    o = organism.strip().lower()
    # The empty string is a substring of every string, so without this guard
    # a blank query would score 0.7 against every organism (and 1.0 against
    # a blank one).  ``score_text_match`` rejects blank input the same way.
    if not q or not o:
        return 0.0
    if q == o:
        return 1.0
    if q in o:
        return 0.7
    return 0.0