"""Discovery/search helper tools (AI-exposed)."""
import re
from typing import Annotated
import httpx
from kani import AIParam, ai_function
from veupath_chatbot.domain.strategy.explain import explain_operation
from veupath_chatbot.domain.strategy.ops import parse_op
from veupath_chatbot.platform.errors import ErrorCode
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import (
JSONArray,
JSONObject,
JSONValue,
as_json_object,
)
from veupath_chatbot.services.catalog.searches import (
get_raw_record_types,
get_raw_searches,
)
from veupath_chatbot.services.strategies.engine.helpers import StrategyToolsHelpers
logger = get_logger(__name__)
[docs]
class StrategyDiscoveryOps(StrategyToolsHelpers):
"""Discovery/search tools."""
[docs]
@ai_function()
async def search_searches_by_keywords(
self,
keywords: Annotated[
list[str] | str,
AIParam(desc="Keywords to match (e.g., ['otto', '2014', 'gametocyte'])"),
],
record_type: Annotated[
str | None, AIParam(desc="Optional record type to restrict the search")
] = None,
limit: Annotated[int, AIParam(desc="Max number of results")] = 20,
) -> JSONObject:
"""Search available questions by keywords across name/display/description."""
if isinstance(keywords, str):
raw_terms = re.findall(r"[A-Za-z0-9]+", keywords)
else:
raw_terms_list: list[str] = []
for item in keywords:
raw_terms_list.extend(re.findall(r"[A-Za-z0-9]+", str(item)))
raw_terms = raw_terms_list
terms = [t.lower() for t in raw_terms if t]
if not terms:
return tool_error(
ErrorCode.VALIDATION_ERROR, "No keywords provided", keywords=[]
)
matches: JSONArray = []
record_types: list[str] = []
resolved_record_type = (
await self._resolve_record_type(record_type) if record_type else None
)
if resolved_record_type:
record_types = [resolved_record_type]
else:
record_types_list: list[str] = []
record_types_raw = await get_raw_record_types(self.session.site_id)
for rt_value in record_types_raw:
if not isinstance(rt_value, dict):
continue
rt = as_json_object(rt_value)
url_segment = rt.get("urlSegment")
name_value = rt.get("name")
rt_name: str | None = None
if isinstance(url_segment, str):
rt_name = url_segment
elif isinstance(name_value, str):
rt_name = name_value
if rt_name:
record_types_list.append(rt_name)
record_types = record_types_list
for rt_name in record_types:
try:
searches = await get_raw_searches(self.session.site_id, rt_name)
except httpx.HTTPError, KeyError:
logger.warning(
"Failed to fetch searches for record type %s",
rt_name,
exc_info=True,
)
continue
for search_value in searches:
if not isinstance(search_value, dict):
continue
search = as_json_object(search_value)
url_segment_value = search.get("urlSegment")
name_value = search.get("name")
display_value = search.get("displayName")
short_value = search.get("shortDisplayName")
description_value = search.get("description")
name = (
str(url_segment_value)
if isinstance(url_segment_value, str)
else (str(name_value) if isinstance(name_value, str) else "")
)
display = str(display_value) if isinstance(display_value, str) else ""
short = str(short_value) if isinstance(short_value, str) else ""
description = (
str(description_value) if isinstance(description_value, str) else ""
)
haystack = " ".join([name, display, short, description]).lower()
score = sum(1 for term in terms if term in haystack)
if score == 0:
continue
matches.append(
{
"recordType": rt_name,
"searchName": name or display,
"displayName": display or name,
"description": description,
"score": score,
"matchedKeywords": [t for t in terms if t in haystack],
}
)
def sort_key(item_value: JSONValue) -> tuple[int, str]:
if not isinstance(item_value, dict):
return (0, "")
item = as_json_object(item_value)
score_value = item.get("score")
display_name_value = item.get("displayName")
score = int(score_value) if isinstance(score_value, (int, float)) else 0
display_name = (
str(display_name_value) if isinstance(display_name_value, str) else ""
)
return (-score, display_name)
matches.sort(key=sort_key)
return {"keywords": terms, "results": matches[: max(limit, 1)]}
[docs]
@ai_function()
async def explain_operator(
self,
operator: Annotated[
str,
AIParam(desc="Operator to explain (INTERSECT, UNION, etc.)"),
],
) -> str:
"""Explain what a combine operator does."""
op = parse_op(operator)
return explain_operation(op)