"""Discovery and caching of record types, searches, and parameters."""
import asyncio
import threading
from veupath_chatbot.integrations.veupathdb.client import VEuPathDBClient
from veupath_chatbot.integrations.veupathdb.param_utils import (
wdk_entity_name,
wdk_search_matches,
)
from veupath_chatbot.integrations.veupathdb.site_router import get_site_router
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONArray, JSONObject
# Module-level logger used by SearchCatalog and DiscoveryService below.
logger = get_logger(__name__)
class SearchCatalog:
    """Cached catalog of record types and searches for a single site.

    :meth:`load` fills the catalog once; the synchronous accessors only read
    the cached data.  Search details are fetched on demand and cached per
    record type, search name and ``expand_params`` flag.
    """

    def __init__(self, site_id: str) -> None:
        """Create an empty catalog that has not been loaded yet.

        :param site_id: Identifier of the site this catalog describes.
        """
        self.site_id = site_id
        self._record_types: JSONArray = []
        self._searches: dict[str, JSONArray] = {}
        self._search_details: dict[str, JSONObject] = {}
        self._loaded = False
        # Serialises load() so concurrent callers do not fetch in parallel.
        self._lock = asyncio.Lock()

    async def load(self, client: VEuPathDBClient) -> None:
        """Load record types and searches from VEuPathDB (at most once).

        Returns immediately when a previous load has already succeeded.
        Results are gathered into local containers and only published to the
        catalog after every record type has been processed.  A load that is
        interrupted part-way (an error, or task cancellation while awaiting
        the client) therefore leaves the catalog untouched, and a later retry
        cannot append the same record types a second time.

        :param client: Client used to query the site.
        :raises ValueError: If the record-types response is a dict that has
            no ``recordTypes`` list.
        :raises Exception: Any other failure while loading is logged and
            re-raised unchanged.
        """
        async with self._lock:
            if self._loaded:
                return
            logger.info("Loading search catalog", site_id=self.site_id)
            try:
                record_types, searches_by_type = await self._collect(client)
                # Publish in one step with no awaits in between.  The
                # containers are updated in place so that a list previously
                # handed out by get_record_types() keeps reflecting the
                # catalog contents.
                self._record_types[:] = record_types
                self._searches.clear()
                self._searches.update(searches_by_type)
                self._loaded = True
                logger.info(
                    "Search catalog loaded",
                    site_id=self.site_id,
                    record_types=len(self._record_types),
                    total_searches=sum(len(s) for s in self._searches.values()),
                )
            except Exception as e:
                logger.error(
                    "Failed to load catalog", site_id=self.site_id, error=str(e)
                )
                raise

    async def _collect(
        self, client: VEuPathDBClient
    ) -> tuple[JSONArray, dict[str, JSONArray]]:
        """Fetch record types and their searches without mutating the catalog.

        :param client: Client used to query the site.
        :returns: ``(record_types, searches_by_record_type)``.
        :raises ValueError: If the record-types response is a dict that has
            no ``recordTypes`` list.
        """
        # Load record types with expanded searches when possible.
        raw_record_types = await client.get_record_types(expanded=True)
        # WDK's record-types endpoint returns an array directly, but some
        # deployments may wrap it under JsonKeys.RECORD_TYPES = "recordTypes".
        if isinstance(raw_record_types, dict):
            wrapped = raw_record_types.get("recordTypes")
            if not isinstance(wrapped, list):
                raise ValueError(
                    f"Unexpected record-types response shape: "
                    f"dict without 'recordTypes' list (keys: {list(raw_record_types.keys())})"
                )
            raw_record_types = wrapped

        # Embedded searches are trusted only if at least one record type
        # actually carries a "searches" key.
        expanded_supported = any(
            isinstance(rt, dict) and "searches" in rt for rt in raw_record_types
        )

        record_types: JSONArray = []
        searches_by_type: dict[str, JSONArray] = {}

        # Entries may be plain strings or dicts; anything else is ignored.
        for rt in raw_record_types:
            if isinstance(rt, str):
                rt_name = rt
                record_types.append({"urlSegment": rt, "name": rt})
                embedded = None
            elif isinstance(rt, dict):
                rt_name = wdk_entity_name(rt)
                record_types.append(rt)
                embedded = rt.get("searches") if expanded_supported else None
            else:
                continue

            if not rt_name:
                continue

            if isinstance(embedded, list) and embedded:
                searches_by_type[rt_name] = embedded
                continue

            # No usable embedded searches: ask the client for this record
            # type.  A failure here is best-effort and only costs that one
            # record type its searches.
            try:
                searches_by_type[rt_name] = await client.get_searches(rt_name)
            except Exception as e:
                logger.warning(
                    "Failed to load searches",
                    record_type=rt_name,
                    error=str(e),
                )

        return record_types, searches_by_type

    def get_record_types(self) -> JSONArray:
        """Get all cached record types (empty until :meth:`load` succeeds)."""
        return self._record_types

    def get_searches(self, record_type: str) -> JSONArray:
        """Get searches for a record type.

        :param record_type: WDK record type.
        :returns: The cached searches, or an empty list when none are cached
            for ``record_type``.
        """
        return self._searches.get(record_type, [])

    def find_search(self, record_type: str, search_name: str) -> JSONObject | None:
        """Find a specific search within one record type.

        :param record_type: WDK record type.
        :param search_name: WDK search name.
        :returns: The first cached search dict whose ``wdk_entity_name``
            equals ``search_name``, or None if there is no such search.
        """
        for search in self.get_searches(record_type):
            # Non-dict entries cannot be named searches; skip them.
            if not isinstance(search, dict):
                continue
            if wdk_entity_name(search) == search_name:
                return search
        return None

    def find_record_type_for_search(self, search_name: str) -> str | None:
        """Find which record type owns a search (global lookup).

        Mirrors WDK's ``WdkModel.getQuestionByName()`` — iterates all cached
        record types to find the one containing the given search.

        :param search_name: WDK search name (urlSegment or name).
        :returns: The record type name, or None if not found.
        """
        for rt_name, searches in self._searches.items():
            if any(wdk_search_matches(s, search_name) for s in searches):
                return rt_name
        return None

    async def get_search_details(
        self,
        client: VEuPathDBClient,
        record_type: str,
        search_name: str,
        expand_params: bool = True,
    ) -> JSONObject:
        """Get detailed search config with caching.

        :param client: Client used to fetch the details on a cache miss.
        :param record_type: WDK record type.
        :param search_name: WDK search name.
        :param expand_params: Passed through to the client; also part of the
            cache key, so both variants are cached separately.
        :returns: The cached (or freshly fetched) search details.
        """
        cache_key = f"{record_type}/{search_name}?expand={int(expand_params)}"
        if cache_key not in self._search_details:
            details = await client.get_search_details(
                record_type, search_name, expand_params=expand_params
            )
            self._search_details[cache_key] = details
        return self._search_details[cache_key]
class DiscoveryService:
    """Per-site registry of :class:`SearchCatalog` objects.

    A catalog is created the first time its site is requested and is loaded
    through the client that the site router returns for that site.
    """

    def __init__(self) -> None:
        """Start with no catalogs and a lock guarding catalog creation."""
        self._lock = asyncio.Lock()
        self._catalogs: dict[str, SearchCatalog] = {}

    async def get_catalog(self, site_id: str) -> SearchCatalog:
        """Return the catalog for a site, creating and loading it as needed.

        :param site_id: Site identifier.
        :returns: The site's catalog, after its ``load`` call has completed.
        """
        # Only the registry lookup/insert runs under the service lock; the
        # load itself is awaited after the lock has been released.
        async with self._lock:
            catalog = self._catalogs.get(site_id)
            if catalog is None:
                catalog = SearchCatalog(site_id)
                self._catalogs[site_id] = catalog
        client = get_site_router().get_client(site_id)
        await catalog.load(client)
        return catalog

    async def get_record_types(self, site_id: str) -> JSONArray:
        """Return the record types cached for a site.

        :param site_id: Site identifier.
        """
        return (await self.get_catalog(site_id)).get_record_types()

    async def get_searches(self, site_id: str, record_type: str) -> JSONArray:
        """Return the searches cached for one record type of a site.

        :param site_id: Site identifier.
        :param record_type: WDK record type.
        """
        return (await self.get_catalog(site_id)).get_searches(record_type)

    async def get_search_details(
        self,
        site_id: str,
        record_type: str,
        search_name: str,
        expand_params: bool = True,
    ) -> JSONObject:
        """Return the detailed configuration of one search.

        :param site_id: Site identifier.
        :param record_type: WDK record type.
        :param search_name: WDK search name.
        :param expand_params: Forwarded to the catalog's detail lookup.
        """
        catalog = await self.get_catalog(site_id)
        client = get_site_router().get_client(site_id)
        return await catalog.get_search_details(
            client, record_type, search_name, expand_params=expand_params
        )

    async def preload_all(self) -> None:
        """Load the catalogs of every site known to the site router.

        All sites are awaited together via ``asyncio.gather``; a site whose
        load raises is logged as a warning and does not abort the others.
        """

        async def _warm(site_id: str) -> None:
            try:
                await self.get_catalog(site_id)
            except Exception as e:
                logger.warning("Failed to preload site", site_id=site_id, error=str(e))

        site_list = get_site_router().list_sites()
        await asyncio.gather(*(_warm(site.id) for site in site_list))
# Process-wide singleton state used by get_discovery_service().
# The lock guards first-time construction of the shared instance.
_discovery_lock = threading.Lock()
# Shared service instance; stays None until it is first requested.
_discovery: DiscoveryService | None = None
def get_discovery_service() -> DiscoveryService:
    """Return the shared :class:`DiscoveryService`, creating it on first use.

    The instance is checked once without the lock and, if still missing,
    checked again under ``_discovery_lock`` before it is constructed, so only
    one instance is ever created.
    """
    global _discovery
    if _discovery is None:
        with _discovery_lock:
            # Re-check: another thread may have built it while we waited.
            if _discovery is None:
                _discovery = DiscoveryService()
    return _discovery