# Source code for veupath_chatbot.services.catalog.searches

"""Search listing and searching functions."""

import math
import re
from collections import Counter

from veupath_chatbot.domain.strategy.ast import PlanStepNode
from veupath_chatbot.domain.strategy.compile import ResolveRecordType
from veupath_chatbot.domain.strategy.tree import collect_plan_leaves
from veupath_chatbot.integrations.veupathdb.discovery import get_discovery_service
from veupath_chatbot.integrations.veupathdb.param_utils import wdk_entity_name
from veupath_chatbot.integrations.veupathdb.site_search import (
    query_site_search,
    strip_html_tags,
)
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONArray, JSONObject

# Module-level logger. Call sites in this module pass context (site id, error
# text) as keyword arguments rather than formatting it into the message.
logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Scoring constants
# ---------------------------------------------------------------------------

# Per-field weights and keyword boost. None of these four are referenced in
# the visible part of this module — presumably they are consumed by
# ``score_search`` (TODO confirm against its definition).
_WEIGHT_SEARCH_NAME = 5.0
_WEIGHT_DISPLAY_NAME = 3.0
_WEIGHT_DESCRIPTION = 1.0
_KEYWORD_BOOST = 20.0
# Minimum query-term length: ``search_for_searches`` ignores shorter terms
# when it builds the corpus term counts.
_MIN_TERM_LEN = 3

# Lowercase record-class key -> human-readable plural label.
# NOTE(review): not referenced in the visible code; presumably used when
# annotating results (``annotate_search``) — confirm.
_RECORD_CLASS_LABELS = {
    "transcript": "genes/transcripts",
    "gene": "genes",
    "snp": "SNPs",
    "popsetsequence": "popset sequences",
    "est": "ESTs",
    "compound": "compounds",
    "pathway": "pathways",
}


# ---------------------------------------------------------------------------
# Scoring, filtering, annotation
# ---------------------------------------------------------------------------














async def get_raw_record_types(site_id: str) -> JSONArray:
    """Fetch the unmodified WDK record type objects for ``site_id``.

    In contrast to :func:`services.catalog.sites.get_record_types`, the full
    WDK payloads (``urlSegment``, ``name``, ``displayName``, etc.) are passed
    through untouched, so callers that need the original structure do not
    have to reach into the integrations layer themselves.

    :param site_id: Identifier of the site whose record types are requested.
    :returns: Whatever the discovery service returns for that site.
    """
    return await get_discovery_service().get_record_types(site_id)
async def get_raw_searches(site_id: str, record_type: str) -> JSONArray:
    """Fetch the unmodified WDK search objects of one record type.

    A thin service-level pass-through to the discovery integration, provided
    so that AI tools and other service consumers never need to import from
    ``integrations/`` directly.

    :param site_id: Identifier of the site to query.
    :param record_type: Record type whose searches are requested.
    :returns: Whatever the discovery service returns for that record type.
    """
    return await get_discovery_service().get_searches(site_id, record_type)
async def list_searches(site_id: str, record_type: str) -> list[dict[str, str]]:
    """Return a compact listing of the searches of one record type.

    Each entry carries **name + displayName only** so the payload stays small
    (VEuPathDB has 2000+ searches; descriptions alone add ~3 MB). The model
    should use ``search_for_searches`` for targeted discovery with
    descriptions, or ``get_search_parameters`` for full details on a specific
    search.

    Entries that are not dicts, and searches whose ``isInternal`` flag is the
    boolean ``True``, are left out. A non-string ``displayName`` becomes ``""``.
    """
    raw_searches = await get_discovery_service().get_searches(site_id, record_type)

    def _summarize(search: dict) -> dict[str, str]:
        name = wdk_entity_name(search)
        label = search.get("displayName")
        return {
            "name": name,
            "displayName": label if isinstance(label, str) else "",
        }

    return [
        _summarize(search)
        for search in raw_searches
        # Only a genuine boolean True hides a search; truthy non-bools do not.
        if isinstance(search, dict) and search.get("isInternal") is not True
    ]
async def list_transforms(site_id: str, record_type: str) -> list[dict[str, str]]:
    """Return the transform/combine searches of a record type, with descriptions.

    Only searches that accept an input step are kept — these are the ones used
    to chain steps together (ortholog transform, weight filter, span logic,
    boolean combine, etc.). There are typically 5-7 per site, so descriptions
    are included.

    A search qualifies when ``allowedPrimaryInputRecordClassNames`` is a
    non-empty list and ``isInternal`` is not the boolean ``True``. Non-string
    ``displayName`` / ``description`` values become ``""``.
    """
    raw_searches = await get_discovery_service().get_searches(site_id, record_type)

    def _accepts_input(search: dict) -> bool:
        inputs = search.get("allowedPrimaryInputRecordClassNames")
        return isinstance(inputs, list) and bool(inputs)

    def _describe(search: dict) -> dict[str, str]:
        name = wdk_entity_name(search)
        label = search.get("displayName")
        blurb = search.get("description")
        return {
            "name": name,
            "displayName": label if isinstance(label, str) else "",
            "description": blurb if isinstance(blurb, str) else "",
        }

    return [
        _describe(search)
        for search in raw_searches
        if isinstance(search, dict)
        and _accepts_input(search)
        and search.get("isInternal") is not True
    ]
def _site_search_doc_to_entry(doc: object) -> dict[str, str] | None:
    """Convert one site-search document into a result entry, or ``None`` to skip it."""
    if not isinstance(doc, dict):
        return None

    # primaryKey is read as [search name, record type, ...]; both are required.
    key = doc.get("primaryKey")
    if not (isinstance(key, list) and len(key) >= 2):
        return None
    search_name = str(key[0] or "").strip()
    record_type = str(key[1] or "").strip()
    if not (search_name and record_type):
        return None

    found = doc.get("foundInFields") or {}
    found_is_dict = isinstance(found, dict)

    # Display name: the hyperlink name wins; otherwise take the first hit from
    # the matched-field lists.
    display = doc.get("hyperlinkName") or ""
    if not display and found_is_dict:
        names = (
            found.get("TEXT__search_displayName") or found.get("autocomplete") or []
        )
        if isinstance(names, list) and names:
            head = names[0]
            display = "" if head is None else str(head)

    raw_description = ""
    if found_is_dict:
        texts = (
            found.get("TEXT__search_description")
            or found.get("TEXT__search_summary")
            or []
        )
        if isinstance(texts, list) and texts:
            raw_description = str(texts[0] or "")

    return {
        "name": search_name,
        "displayName": strip_html_tags(str(display or "")) or search_name,
        "description": strip_html_tags(raw_description),
        "recordType": record_type,
    }


async def _search_for_searches_via_site_search(
    site_id: str,
    query: str,
    *,
    limit: int = 20,
) -> list[dict[str, str]]:
    """Look up WDK searches through the site's /site-search service.

    This mirrors the webapp search UI (`/app/search`) when filtering to
    documentType=search. Any failure of the lookup is logged as a warning and
    yields an empty list, so callers can fall back to discovery-based search.

    :returns: At most ``limit`` entries with ``name``, ``displayName``,
        ``description`` and ``recordType``, one per search name, with
        transcript/gene record types first.
    """
    try:
        data = await query_site_search(
            site_id,
            search_text=query,
            document_type="search",
            limit=limit,
            offset=0,
        )
    except Exception as exc:
        logger.warning(
            "Site-search lookup failed; falling back to discovery search",
            site_id=site_id,
            error=str(exc),
        )
        return []

    # Defensive unwrapping: anything with an unexpected shape is treated as empty.
    payload = data if isinstance(data, dict) else {}
    section = payload.get("searchResults")
    documents = section.get("documents") if isinstance(section, dict) else None
    if not isinstance(documents, list):
        documents = []

    entries = [
        entry
        for doc in documents
        if (entry := _site_search_doc_to_entry(doc)) is not None
    ]

    # Boost transcript/gene results to the top — the model almost always
    # builds gene strategies, so EST/Popset/compound matches are noise.
    # sorted() is stable, so the service's own order is kept within a tier.
    ranked = sorted(
        entries,
        key=lambda entry: _record_type_priority(entry.get("recordType", "")),
    )

    # Deduplicate: the same search can appear for multiple record types;
    # keep only the highest-priority (first after sorting) occurrence.
    first_by_name: dict[str, dict[str, str]] = {}
    for entry in ranked:
        first_by_name.setdefault(entry["name"], entry)
    return list(first_by_name.values())[:limit]


# Record types the model cares about most, in priority order.
_PREFERRED_RECORD_TYPES = ("transcript", "gene")


def _record_type_priority(record_type: str) -> int:
    """Return a sort rank for a record type: lower = higher priority.

    The rank is the index of the first preferred type that occurs as a
    case-insensitive substring of ``record_type``; everything else gets 100.
    """
    lowered = record_type.lower()
    return next(
        (
            rank
            for rank, preferred in enumerate(_PREFERRED_RECORD_TYPES)
            if preferred in lowered
        ),
        100,
    )
async def search_for_searches(
    site_id: str,
    record_type: str | list[str] | None,
    query: str,
    *,
    keywords: list[str] | None = None,
    limit: int = 20,
) -> list[dict[str, str]]:
    """Find searches matching a query and/or keywords.

    Candidates are the non-internal, non-chooser searches of the requested
    record types. Each one is scored by ``score_search`` (which receives the
    query terms, the keywords and corpus-wide term counts for IDF), annotated
    via ``annotate_search``, and optionally boosted when the site-search
    service also returns it.

    :param site_id: Identifier of the site to search.
    :param record_type: A single record type, a list of them, or ``None`` /
        empty to search every record type the site exposes.
    :param query: Free-text query; split into lowercase ``[A-Za-z0-9_]+`` terms.
    :param keywords: Optional extra keywords forwarded to ``score_search``.
    :param limit: Maximum number of entries to return. Values ``<= 0`` yield
        an empty list without contacting any service.
    :returns: Up to ``limit`` entries (``name``, ``displayName``,
        ``description``, ``recordType`` plus annotations), best score first,
        at most one per search name.
    """
    # The cap below is only checked after an append, so without this guard a
    # non-positive limit would still return one entry.
    if limit <= 0:
        return []

    kw_list = keywords or []
    discovery = get_discovery_service()

    # --- Resolve record types ---
    record_types: list[str] = []
    if isinstance(record_type, list):
        record_types = [str(rt) for rt in record_type if rt]
    elif isinstance(record_type, str) and record_type:
        record_types = [record_type]
    record_types = list(dict.fromkeys(record_types))  # de-dupe, keep order

    if not record_types:
        record_types_raw = await discovery.get_record_types(site_id)
        record_types = [
            name for rt in record_types_raw if (name := wdk_entity_name(rt))
        ]

    terms = [t.lower() for t in re.findall(r"[A-Za-z0-9_]+", query or "")]

    # --- Collect all candidate searches ---
    # Text fields are normalised once here so that the IDF pass and the
    # scoring pass see exactly the same strings.
    # (raw_search, record_type, name, display, description)
    candidates: list[tuple[JSONObject, str, str, str, str]] = []
    for rt_name in record_types:
        searches = await discovery.get_searches(site_id, rt_name)
        for s in searches:
            if not isinstance(s, dict):
                continue
            if s.get("isInternal") is True:
                continue
            if is_chooser_search(s):
                continue
            canonical_name = wdk_entity_name(s)
            display_raw = s.get("displayName")
            display = display_raw if isinstance(display_raw, str) else canonical_name
            desc_raw = s.get("description")
            desc = desc_raw if isinstance(desc_raw, str) else ""
            candidates.append((s, rt_name, canonical_name, display, desc))

    # --- Build corpus term counts for IDF ---
    # Document frequency = number of candidates containing the term, so each
    # distinct term is counted at most once per candidate even when the query
    # repeats it.
    countable_terms = list(
        dict.fromkeys(t for t in terms if len(t) >= _MIN_TERM_LEN)
    )
    corpus_counts: Counter[str] = Counter()
    for _, _, name, display, desc in candidates:
        haystack = f"{name} {display} {desc}".lower()
        corpus_counts.update(t for t in countable_terms if t in haystack)
    term_counts = dict(corpus_counts)  # loop-invariant: built once
    doc_count = len(candidates)

    # --- Score each candidate ---
    scored: list[tuple[float, dict[str, str]]] = []
    for s, rt_name, canonical_name, display, desc in candidates:
        sc = score_search(
            query_terms=terms,
            keywords=kw_list,
            search_name=canonical_name,
            display_name=display,
            description=desc,
            corpus_doc_count=doc_count,
            corpus_term_counts=term_counts,
        )
        if sc <= 0:
            continue
        entry: dict[str, str] = {
            "name": canonical_name,
            "displayName": display,
            "description": desc,
            "recordType": rt_name,
        }
        entry.update(annotate_search(s))
        scored.append((sc, entry))

    # Site-search can only boost entries that already scored, so there is
    # nothing left to do (and no reason for another round-trip) if none did.
    if not scored:
        return []

    # --- Merge site-search results (supplementary boost) ---
    # Only boost entries that already scored > 0 on keyword matching.
    # Don't add new entries from site-search that didn't match our scoring.
    try:
        site_results = await _search_for_searches_via_site_search(
            site_id, query, limit=limit
        )
        site_bonus: dict[str, float] = {}
        for rank, sr in enumerate(site_results):
            name = sr.get("name", "")
            if name and name not in site_bonus:
                site_bonus[name] = 5.0 / (1 + rank)
        for i, (sc, entry) in enumerate(scored):
            bonus = site_bonus.get(entry["name"], 0.0)
            if bonus > 0:
                scored[i] = (sc + bonus, entry)
    except Exception as exc:
        # Best-effort by design, but keep the cause visible for debugging.
        logger.debug("Site-search merge failed (non-fatal)", error=str(exc))

    # --- Sort by score desc, then record type priority ---
    scored.sort(
        key=lambda item: (
            -item[0],
            _record_type_priority(item[1].get("recordType", "")),
            item[1].get("displayName", ""),
        )
    )

    # --- Deduplicate and cap ---
    seen: set[str] = set()
    result: list[dict[str, str]] = []
    for _, entry in scored:
        name = entry.get("name", "")
        if name in seen:
            continue
        seen.add(name)
        result.append(entry)
        if len(result) >= limit:
            break
    return result
async def make_record_type_resolver(site_id: str) -> ResolveRecordType:
    """Build a record type resolver on top of the pre-cached SearchCatalog.

    Mirrors WDK's ``WdkModel.getQuestionByName()`` — a global lookup that
    finds which record type owns a given search name, using the
    already-cached catalog data (no HTTP calls at resolve time).

    :param site_id: Site whose catalog backs the resolver.
    :returns: An async callable mapping a search name to whatever the
        catalog's ``find_record_type_for_search`` returns for it.
    """
    catalog = await get_discovery_service().get_catalog(site_id)

    async def resolve(search_name: str) -> str | None:
        # The catalog is fetched once above and captured by this closure.
        return catalog.find_record_type_for_search(search_name)

    return resolve
async def resolve_record_type_from_steps(
    root_step: PlanStepNode,
    resolver: ResolveRecordType,
) -> str | None:
    """Return the record type of the first resolvable leaf search in a step tree.

    The leaves yielded by :func:`collect_plan_leaves` are tried in order; each
    leaf's ``search_name`` is handed to ``resolver`` and the first truthy
    answer wins.

    :param root_step: Root of the plan step tree to inspect.
    :param resolver: Async callable mapping a search name to its record type.
    :returns: The first truthy resolver result, or ``None`` if no leaf resolves.
    """
    for leaf in collect_plan_leaves(root_step):
        if owner := await resolver(leaf.search_name):
            return owner
    return None