"""WDK parameter fetching, caching, and expansion."""
from typing import Any, cast
from veupath_chatbot.domain.parameters.normalize import ParameterNormalizer
from veupath_chatbot.domain.parameters.specs import (
adapt_param_specs,
extract_param_specs,
find_input_step_param,
unwrap_search_data,
)
from veupath_chatbot.domain.parameters.vocab_utils import flatten_vocab
from veupath_chatbot.integrations.veupathdb.client import VEuPathDBClient
from veupath_chatbot.integrations.veupathdb.discovery import get_discovery_service
from veupath_chatbot.integrations.veupathdb.factory import get_wdk_client
from veupath_chatbot.integrations.veupathdb.param_utils import (
normalize_param_value,
wdk_entity_name,
wdk_search_matches,
)
from veupath_chatbot.platform.errors import ErrorCode, WDKError
from veupath_chatbot.platform.errors import ValidationError as CoreValidationError
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import JSONArray, JSONObject, JSONValue
from veupath_chatbot.services.wdk.record_types import resolve_record_type
from .searches import find_record_type_for_search
# Module logger. Call sites in this module pass structured context as keyword
# arguments (e.g. ``site_id=...``, ``error=...``).
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Extracted helpers
# ---------------------------------------------------------------------------
def _allowed_values(
vocab: JSONObject | JSONArray | None,
) -> list[JSONObject]:
"""Extract WDK-accepted parameter values from a vocabulary.
Returns ``[{"value": <wdk_value>, "display": <label>}, ...]`` so the LLM
knows both *what to pass* and *what it means*.
:param vocab: Vocabulary tree or flat list from catalog.
:returns: List of value/display dicts (capped at 50).
"""
if not vocab:
return []
entries: list[JSONObject] = []
seen: set[str] = set()
for entry in flatten_vocab(vocab, prefer_term=True):
# Prefer the WDK-accepted value; fall back to display if missing.
candidate = entry.get("value") or entry.get("display")
if not candidate:
continue
text = str(candidate)
if text in seen:
continue
seen.add(text)
display = entry.get("display")
display_str = str(display) if display else text
entries.append({"value": text, "display": display_str})
if len(entries) >= 50:
break
return entries
# Phyletic helper parameters of the ortholog-pattern search: the code/label
# map (``phyletic_term_map``) and the code/depth map (``phyletic_indent_map``).
# _format_param_info omits them from tool output so the model never sets them.
_PHYLETIC_STRUCTURAL_PARAMS = frozenset({"phyletic_indent_map", "phyletic_term_map"})
# Help text substituted for the ``profile_pattern`` parameter's WDK help. It
# documents the %CODE:STATE[:QUANTIFIER]% encoding and the requirement to
# select all relevant organisms in the ``organism`` parameter.
_PROFILE_PATTERN_HELP = (
    "Phylogenetic profile pattern. Format: %CODE:STATE[:QUANTIFIER]% (percent-delimited).\n"
    " CODE = species or group code from lookup_phyletic_codes()\n"
    " STATE = Y (present) or N (absent)\n"
    " QUANTIFIER = 'any' or 'all' (optional, only matters for group codes)\n"
    "\n"
    "For leaf species codes (e.g. pfal, hsap), quantifier is ignored:\n"
    " pfal:Y → present in P. falciparum\n"
    " hsap:N → absent from H. sapiens\n"
    "\n"
    "For group codes (e.g. MAMM, APIC), quantifier controls expansion:\n"
    " MAMM:N → absent from ALL mammals (default for :N)\n"
    " MAMM:N:all → same as above (explicit)\n"
    " APIC:Y:any → present in ANY Apicomplexa (default for :Y, dropped from pattern)\n"
    " APIC:Y:all → present in ALL Apicomplexa (expanded, usually 0 results)\n"
    "\n"
    "Example: '%MAMM:N%pfal:Y%' → P.falciparum present, all mammals absent\n"
    "\n"
    "CRITICAL: The 'organism' parameter controls which organisms' genes appear in "
    "results. You MUST select ALL relevant organisms (use all leaf values from the "
    "organism vocabulary tree, or use the tree's root '@@fake@@' sentinel for 'select all'). "
    "If you only select one organism, you will get 0 results even if the pattern is correct."
)
def _render_vocab_tree(
    node: JSONObject,
    *,
    max_lines: int = 80,
    _depth: int = 0,
    _lines: list[str] | None = None,
) -> list[str]:
    """Render a WDK tree vocabulary as indented text lines.

    Each line is ``" " * depth + term``. The ``@@fake@@`` sentinel node is not
    rendered, but its children are. Rendering stops once *max_lines* term
    lines have been collected, to avoid blowing up the tool response for huge
    trees. When any node is skipped because of that limit, exactly one
    ``" ... (truncated)"`` marker is appended.

    :param node: Vocabulary tree node.
    :param max_lines: Maximum number of term lines to render.
    :param _depth: Internal recursion depth; callers should not pass it.
    :param _lines: Internal accumulator shared by all recursive calls.
    :returns: The accumulated list of rendered lines.
    """
    truncation_marker = " ... (truncated)"
    if _lines is None:
        _lines = []
    if len(_lines) >= max_lines:
        return _lines
    from veupath_chatbot.domain.parameters.vocab_utils import (
        get_node_term,
        get_vocab_children,
    )

    term = get_node_term(node)
    if term and term != "@@fake@@":
        _lines.append(f"{' ' * _depth}{term}")
    for child in get_vocab_children(node):
        if len(_lines) >= max_lines:
            # ``_lines`` is shared across recursion levels. An ancestor frame
            # can reach this point after a descendant already appended the
            # marker, so only add it if it is not there yet.
            if not _lines or _lines[-1] != truncation_marker:
                _lines.append(truncation_marker)
            break
        _render_vocab_tree(child, max_lines=max_lines, _depth=_depth + 1, _lines=_lines)
    return _lines
def _format_param_info(param_specs: JSONArray) -> JSONArray:
    """Turn raw WDK parameter specs into the tool-facing parameter listing.

    Every usable spec (a dict with a non-empty string ``name``) becomes a dict
    with the keys ``name``, ``displayName``, ``type``, ``required``,
    ``isVisible`` and ``help``. Depending on the spec, it may also carry:

    - ``allowedValues`` or ``allowedValues_tree``
    - ``defaultValue``
    - ``controlsVocabOf``, ``vocabDependsOn`` and ``note``

    The phyletic structural params (``phyletic_indent_map``,
    ``phyletic_term_map``) are dropped because the model should never set them
    directly. ``profile_pattern`` gets its help replaced by the pattern-encoding
    documentation.

    :param param_specs: Raw parameter spec dicts from WDK.
    :returns: One formatted info dict per retained parameter, in input order.
    """
    dict_specs = [spec for spec in param_specs if isinstance(spec, dict)]

    # Dependency bookkeeping in both directions:
    #   controls[parent]  -> params whose vocabulary the parent drives
    #   depends_on[child] -> parents that drive the child's vocabulary
    controls: dict[str, list[str]] = {}
    depends_on: dict[str, list[str]] = {}
    for spec in dict_specs:
        owner_raw = spec.get("name")
        owner = str(owner_raw) if owner_raw else ""
        children_raw = spec.get("dependentParams")
        if not owner or not isinstance(children_raw, list) or not children_raw:
            continue
        children = [str(child) for child in children_raw]
        controls[owner] = children
        for child in children:
            depends_on.setdefault(child, []).append(owner)

    formatted: JSONArray = []
    for spec in dict_specs:
        name = spec.get("name")
        if not isinstance(name, str) or not name:
            continue
        if name in _PHYLETIC_STRUCTURAL_PARAMS:
            continue  # the model must not set these directly

        raw_display = spec.get("displayName")
        raw_type = spec.get("type")
        raw_visible = spec.get("isVisible")
        param_type = raw_type if isinstance(raw_type, str) else "string"
        if name == "profile_pattern":
            help_text = _PROFILE_PATTERN_HELP
        else:
            raw_help = spec.get("help")
            help_text = raw_help if isinstance(raw_help, str) else ""

        entry: JSONObject = {
            "name": name,
            "displayName": raw_display if isinstance(raw_display, str) else name,
            "type": param_type,
            "required": not spec.get("allowEmptyValue"),
            "isVisible": raw_visible if isinstance(raw_visible, bool) else True,
            "help": help_text,
        }

        raw_vocab = spec.get("vocabulary")
        vocab = raw_vocab if isinstance(raw_vocab, (dict, list)) else None
        if param_type == "multi-pick-vocabulary" and isinstance(vocab, dict):
            # Show tree vocabularies indented so parent/child structure is
            # visible and parent nodes can be used for auto-expansion.
            rendered = _render_vocab_tree(vocab, max_lines=80)
            if rendered:
                entry["allowedValues_tree"] = cast(
                    JSONValue,
                    "\n".join(rendered)
                    + "\n(Pass a parent node to auto-select all its children)",
                )
        else:
            choices = _allowed_values(vocab)
            if choices:
                entry["allowedValues"] = cast(JSONValue, choices)

        # initialDisplayValue takes precedence over defaultValue.
        default = spec.get("initialDisplayValue")
        if default is None:
            default = spec.get("defaultValue")
        if default is not None:
            entry["defaultValue"] = default

        if name in controls:
            entry["controlsVocabOf"] = cast(JSONValue, controls[name])
        parents = depends_on.get(name)
        if parents:
            first_parent = parents[0]
            entry["vocabDependsOn"] = cast(JSONValue, parents)
            entry["note"] = (
                f"The allowed values for this param change based on the value of "
                f"{', '.join(parents)}. The values shown here are for the default "
                f"context only. Use get_dependent_vocab(search_name, param_name='{name}', "
                f"context_values={{'{first_parent}': '<your chosen value>'}}) to see "
                f"the full vocabulary after setting {first_parent}."
            )
        formatted.append(entry)
    return formatted
async def _fetch_search_details(
discovery: Any,
site_id: str,
resolved_record_type: str,
search_name: str,
*,
record_types: list[Any] | None = None,
) -> tuple[JSONObject, str]:
"""Fetch search details, falling back to scanning all record types.
:param discovery: Discovery service instance.
:param site_id: Site identifier.
:param resolved_record_type: Record type to try first.
:param search_name: Name of the search.
:param record_types: All available record types (for fallback scan).
:returns: Tuple of (details dict, resolved record type).
:raises CoreValidationError: When the search cannot be found.
"""
try:
details = await discovery.get_search_details(
site_id, resolved_record_type, search_name, expand_params=True
)
return details, resolved_record_type
except Exception as e:
return await _fallback_scan_record_types(
discovery,
site_id,
resolved_record_type,
search_name,
record_types=record_types or [],
original_error=e,
)
async def _fallback_scan_record_types(
    discovery: Any,
    site_id: str,
    resolved_record_type: str,
    search_name: str,
    *,
    record_types: list[Any],
    original_error: Exception,
) -> tuple[JSONObject, str]:
    """Scan all record types for *search_name*, raising if it cannot be found.

    Used after the initial details fetch for *resolved_record_type* failed.
    The first record type whose search list contains a match (per
    ``wdk_search_matches``) is queried again for expanded details.

    The scan is best-effort:

    - A record type whose search list cannot be fetched is logged and skipped.
    - If listing the searches for the error payload fails, ``availableSearches``
      is empty; that failure does not replace the ``CoreValidationError``.

    :param discovery: Discovery service instance.
    :param site_id: Site identifier.
    :param resolved_record_type: Record type that was tried first.
    :param search_name: Name of the search.
    :param record_types: Record types to scan; non-dict or unnamed entries are
        ignored.
    :param original_error: Error from the initial details fetch. It is chained
        onto the raised error and echoed in its ``details`` field.
    :returns: Tuple of (details dict, record type the details were fetched for).
    :raises CoreValidationError: When no record type yields the search details.
    """
    details: JSONObject | None = None
    for rt in record_types:
        if not isinstance(rt, dict):
            continue
        rt_name = wdk_entity_name(rt)
        if not rt_name:
            continue
        try:
            searches = await discovery.get_searches(site_id, rt_name)
        except Exception as exc:
            # One unreachable record type must not abort the whole scan.
            logger.warning(
                "Failed to list searches while scanning record types",
                site_id=site_id,
                record_type=rt_name,
                error=str(exc),
            )
            continue
        if any(wdk_search_matches(s, search_name) for s in searches):
            resolved_record_type = rt_name
            try:
                details = await discovery.get_search_details(
                    site_id, rt_name, search_name, expand_params=True
                )
            except Exception as exc:
                logger.warning(
                    "Failed to load search details for matching record type",
                    site_id=site_id,
                    record_type=rt_name,
                    search_name=search_name,
                    error=str(exc),
                )
                details = None
            # Only the first matching record type is tried.
            break
    if details is not None:
        return details, resolved_record_type

    try:
        available = await discovery.get_searches(site_id, resolved_record_type)
    except Exception as exc:
        # The listing only decorates the error payload; it must never mask
        # the "search not found" error itself.
        logger.warning(
            "Failed to list available searches for error report",
            site_id=site_id,
            record_type=resolved_record_type,
            error=str(exc),
        )
        available = []
    available_searches: list[str] = [
        name
        for s in available
        if isinstance(s, dict) and (name := wdk_entity_name(s))
    ]
    error_dict: JSONObject = {
        "path": "searchName",
        "message": f"Search not found: {search_name}",
        "code": ErrorCode.SEARCH_NOT_FOUND.value,
        "recordType": resolved_record_type,
        "searchName": search_name,
        "availableSearches": cast(JSONValue, available_searches),
        "details": str(original_error),
    }
    raise CoreValidationError(
        title="Search not found",
        detail=f"Search not found: {search_name}",
        errors=[error_dict],
    ) from original_error
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
[docs]
async def get_search_parameters(
    site_id: str,
    record_type: str,
    search_name: str,
) -> JSONObject:
    """Get detailed parameter info for a specific search.

    This is intentionally defensive: WDK responses can vary by site/endpoint.

    Steps:

    1. Resolve *record_type* against the site's record types, keeping the
       caller's value when it is empty or cannot be resolved.
    2. Fetch the expanded search details, scanning other record types if the
       first lookup fails.
    3. Format each parameter for tool output.

    :param site_id: Site identifier.
    :param record_type: Record type hint from the caller.
    :param search_name: Name of the search.
    :returns: Dict with ``searchName``, ``displayName``, ``description``,
        ``parameters`` and ``resolvedRecordType``.
    :raises CoreValidationError: When the search cannot be found.
    """
    discovery = get_discovery_service()
    record_types = await discovery.get_record_types(site_id)
    target_record_type = (
        resolve_record_type(record_types, record_type) if record_type else None
    ) or record_type

    raw_details, target_record_type = await _fetch_search_details(
        discovery,
        site_id,
        target_record_type,
        search_name,
        record_types=record_types,
    )
    unwrapped = unwrap_search_data(raw_details) or raw_details
    payload = unwrapped if isinstance(unwrapped, dict) else {}
    parameters = _format_param_info(extract_param_specs(payload))

    title = payload.get("displayName")
    summary = payload.get("description")
    return {
        "searchName": search_name,
        "displayName": title if isinstance(title, str) else search_name,
        "description": summary if isinstance(summary, str) else "",
        "parameters": parameters,
        "resolvedRecordType": target_record_type,
    }
[docs]
async def lookup_phyletic_codes(
    site_id: str,
    record_type: str,
    query: str,
) -> JSONObject:
    """Search phyletic species codes by name for the GenesByOrthologPattern search.

    Returns matching ``{code, label, leaf}`` entries from the
    ``phyletic_term_map`` vocabulary. The model uses codes to build
    ``profile_pattern`` values. ``leaf`` is ``False`` for group codes, i.e.
    codes followed by a deeper entry in ``phyletic_indent_map``. At most 20
    matches are returned, and ``total`` is the number returned.

    Any failure (including the search not being found) is returned as a tool
    error payload instead of being raised.

    :param site_id: Site ID.
    :param record_type: Record type (usually "transcript").
    :param query: Species/clade name search term (case-insensitive substring,
        matched against both label and code).
    :returns: Dict with ``query``, ``matches``, ``total`` and ``hint`` keys, or
        a tool error payload on failure.
    """
    try:
        discovery = get_discovery_service()
        record_types = await discovery.get_record_types(site_id)
        resolved = resolve_record_type(record_types, record_type) or record_type
        details, _ = await _fetch_search_details(
            discovery,
            site_id,
            resolved,
            "GenesByOrthologPattern",
            record_types=record_types,
        )
        details = unwrap_search_data(details) or details
        specs = extract_param_specs(details if isinstance(details, dict) else {})
        term_map_vocab, indent_map_vocab = _phyletic_vocabs(specs)
        group_codes = _phyletic_group_codes(indent_map_vocab)

        q = query.lower().strip()
        matches: list[JSONObject] = []
        for entry in term_map_vocab:
            if not isinstance(entry, list) or len(entry) < 2:
                continue
            code = str(entry[0])
            label = str(entry[1])
            if code == "ALL":
                continue
            if q in label.lower() or q in code.lower():
                matches.append({"code": code, "label": label, "leaf": code not in group_codes})
                if len(matches) >= 20:
                    break
        return {
            "query": query,
            "matches": cast(JSONValue, matches),
            "total": len(matches),
            "hint": (
                "Use codes in profile_pattern: %CODE:Y% (include) or %CODE:N% (exclude). "
                "Example: '%MAMM:N%pfal:Y%'. "
                "Group codes (leaf=false) support optional quantifier: "
                "MAMM:N:all (absent from all, default for :N), "
                "APIC:Y:any (present in any, default for :Y). "
                "Leaf codes need no quantifier."
            ),
        }
    except Exception as exc:
        return tool_error(
            ErrorCode.INTERNAL_ERROR,
            f"Failed to look up phyletic codes: {exc}",
        )


def _phyletic_vocabs(specs: JSONArray) -> tuple[JSONArray, JSONArray]:
    """Return the list vocabularies of ``phyletic_term_map`` and ``phyletic_indent_map``.

    A vocabulary is ``[]`` when its spec is absent or its vocabulary is not a
    list. If a name occurs more than once, the last list-valued spec wins.
    """
    term_map_vocab: JSONArray = []
    indent_map_vocab: JSONArray = []
    for spec in specs:
        if not isinstance(spec, dict):
            continue
        vocab = spec.get("vocabulary")
        if not isinstance(vocab, list):
            continue
        name = spec.get("name")
        if name == "phyletic_term_map":
            term_map_vocab = vocab
        elif name == "phyletic_indent_map":
            indent_map_vocab = vocab
    return term_map_vocab, indent_map_vocab


def _parse_indent_depth(raw: JSONValue) -> int | None:
    """Parse an indent-map depth. ``None`` counts as 0; unparsable values give ``None``."""
    if raw is None:
        return 0
    try:
        return int(str(raw))
    except ValueError:
        return None


def _phyletic_group_codes(indent_map_vocab: JSONArray) -> set[str]:
    """Return indent-map codes that have children (non-leaf "group" codes).

    Entries are ``[code, depth, null]``. A code is a group when the entry
    immediately after it has a strictly greater depth. Malformed entries, and
    pairs whose depth cannot be parsed, are skipped instead of failing the
    whole lookup.
    """
    group_codes: set[str] = set()
    for current, nxt in zip(indent_map_vocab, indent_map_vocab[1:]):
        if not isinstance(current, list) or len(current) < 2:
            continue
        if not isinstance(nxt, list) or len(nxt) < 2:
            continue
        depth = _parse_indent_depth(current[1])
        next_depth = _parse_indent_depth(nxt[1])
        if depth is None or next_depth is None:
            continue
        if next_depth > depth:
            group_codes.add(str(current[0]))
    return group_codes
[docs]
async def expand_search_details_with_params(
    site_id: str,
    record_type: str,
    search_name: str,
    context_values: JSONObject | None,
) -> JSONObject:
    """Return WDK search details after applying (WDK-wire) context values.

    NOTE: despite the historical name, this is *not* a pure validation API; it returns
    WDK search details payload. Keep it separate from the public validation endpoint.

    Steps:

    1. Drop context keys the discovery details do not list as parameters.
    2. Normalize the remaining values against the parameter specs. If that
       raises ``CoreValidationError``, refetch the details with the filtered
       context and normalize against the refetched specs.
    3. Request contextual search details, falling back to the portal on
       ``WDKError``.

    :param site_id: Site identifier.
    :param record_type: Record type hint from the caller.
    :param search_name: Name of the search.
    :param context_values: Caller-supplied parameter values (may be ``None``).
    :returns: WDK search details payload for the given context.
    :raises CoreValidationError: When normalization fails against both the
        discovery specs and the refetched specs.
    """
    client = get_wdk_client(site_id)
    raw_context = context_values or {}
    details, allowed = await _load_discovery_details_and_allowed(
        site_id=site_id,
        record_type=record_type,
        search_name=search_name,
    )
    filtered_context = _filter_context_values(raw_context, allowed)
    details_unwrapped = unwrap_search_data(details)
    specs = adapt_param_specs(details_unwrapped) if details_unwrapped else {}

    # Resolved at most once and reused for the final request.
    resolved_record_type: str | None = None
    normalized_context: JSONObject
    if specs:
        normalizer = ParameterNormalizer(specs)
        try:
            normalized_context = normalizer.normalize(filtered_context)
        except CoreValidationError:
            # Normalization against the discovery specs failed. Refetch the
            # details with the caller's context and normalize against those.
            resolved_record_type = await find_record_type_for_search(
                site_id, record_type, search_name
            )
            refreshed = await client.get_search_details_with_params(
                resolved_record_type,
                search_name,
                filtered_context,
                expand_params=True,
            )
            refreshed = unwrap_search_data(refreshed) or refreshed
            specs = adapt_param_specs(refreshed if isinstance(refreshed, dict) else {})
            normalizer = ParameterNormalizer(specs)
            normalized_context = normalizer.normalize(filtered_context)
        input_step_param = find_input_step_param(specs)
        if input_step_param:
            normalized_context[input_step_param] = ""
    else:
        normalized_context = {
            key: normalize_param_value(value) for key, value in filtered_context.items()
        }

    if resolved_record_type is None:
        resolved_record_type = await find_record_type_for_search(
            site_id, record_type, search_name
        )
    return await _get_search_details_with_portal_fallback(
        site_id=site_id,
        client=client,
        record_type=resolved_record_type,
        search_name=search_name,
        context_values=normalized_context,
    )
def _filter_context_values(raw_context: JSONObject, allowed: set[str]) -> JSONObject:
"""Filter context values to keys WDK recognizes for the search (best-effort).
:param raw_context: Raw context from request.
:param allowed: Set of allowed parameter names.
:returns: Filtered context dict.
"""
return (
{key: value for key, value in raw_context.items() if key in allowed}
if allowed
else dict(raw_context)
)
async def _load_discovery_details_and_allowed(
    *, site_id: str, record_type: str, search_name: str
) -> tuple[JSONObject | None, set[str]]:
    """Load discovery search details and the parameter names they declare.

    Best-effort: any exception is logged as a warning and ``(None, set())`` is
    returned, so callers can continue without discovery data.

    :returns: Tuple of (expanded details or ``None``, set of parameter names).
    """
    try:
        service = get_discovery_service()
        search_details = await service.get_search_details(
            site_id, record_type, search_name, expand_params=True
        )
        param_names = _extract_param_names(
            search_details if isinstance(search_details, dict) else {}
        )
    except Exception as exc:
        logger.warning(
            "Failed to load discovery details for param resolution",
            site_id=site_id,
            record_type=record_type,
            search_name=search_name,
            error=str(exc),
        )
        return None, set()
    return search_details, param_names
async def _get_search_details_with_portal_fallback(
*,
site_id: str,
client: VEuPathDBClient,
record_type: str,
search_name: str,
context_values: JSONObject,
) -> JSONObject:
"""Call WDK contextual search details, falling back to portal when appropriate."""
try:
return await client.get_search_details_with_params(
record_type,
search_name,
context_values,
)
except WDKError:
if site_id != "veupathdb":
portal_client = get_wdk_client("veupathdb")
return await portal_client.get_search_details_with_params(
record_type,
search_name,
context_values,
)
raise
[docs]
async def get_refreshed_dependent_params(
    *,
    site_id: str,
    record_type: str,
    search_name: str,
    parameter_name: str,
    context_values: JSONObject,
) -> JSONObject:
    """Get refreshed dependent parameter vocabulary, falling back to the portal.

    The site-specific WDK client is tried first. If it raises ``WDKError``
    and the site is not already ``veupathdb``, the request is retried against
    the portal client (``veupathdb``). On the portal itself, the error
    propagates.

    :param site_id: Site identifier.
    :param record_type: WDK record type.
    :param search_name: WDK search name.
    :param parameter_name: The dependent parameter to refresh.
    :param context_values: Current context parameter values.
    :returns: Refreshed dependent param payload from WDK.
    """
    site_client = get_wdk_client(site_id)
    try:
        return await site_client.get_refreshed_dependent_params(
            record_type,
            search_name,
            parameter_name,
            context_values,
        )
    except WDKError:
        if site_id == "veupathdb":
            raise
        portal = get_wdk_client("veupathdb")
        return await portal.get_refreshed_dependent_params(
            record_type,
            search_name,
            parameter_name,
            context_values,
        )
def _names_from_param_list(params: list[JSONValue]) -> set[str]:
"""Extract name strings from a list of param dicts."""
names: set[str] = set()
for p in params:
if isinstance(p, dict):
name = p.get("name")
if isinstance(name, str):
names.add(name)
return names
def _extract_param_names(details: JSONObject) -> set[str]:
"""Extract parameter names from WDK search details.
Checks ``details.searchData.parameters`` first, then ``details.parameters``.
"""
if not isinstance(details, dict):
return set()
unwrapped = unwrap_search_data(details) or details
params = unwrapped.get("parameters") if isinstance(unwrapped, dict) else None
if isinstance(params, list):
return _names_from_param_list(params)
if isinstance(params, dict):
return {k for k in params if k}
return set()