# Source code for veupath_chatbot.ai.tools.planner.optimization_tools
"""Planner-mode tool for search parameter optimization.
Provides :class:`OptimizationToolsMixin` with the long-running
``optimize_search_parameters`` tool.
"""
import json
import logging
import math
from typing import Annotated, cast

from kani import AIParam, ai_function

from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.types import (
    ControlValueFormat,
    OptimizationObjective,
)
from veupath_chatbot.services.parameter_optimization import (
    OptimizationConfig,
    OptimizationMethod,
    ParameterSpec,
)
from veupath_chatbot.services.parameter_optimization import (
    optimize_search_parameters as _run_optimization,
)
from veupath_chatbot.services.parameter_optimization import (
    result_to_json as _opt_result_to_json,
)
# [docs]
_logger = logging.getLogger(__name__)

# Accepted values for the enumerated string arguments of the tool.
_VALID_OBJECTIVES = ("f1", "f_beta", "recall", "precision", "custom")
_VALID_METHODS = ("bayesian", "grid", "random")
_VALID_CONTROL_FORMATS: tuple[str, ...] = ("newline", "json_list", "comma")
_VALID_PARAM_TYPES = ("numeric", "integer", "categorical")

# float() raises OverflowError (not ValueError) for ints too large for a double.
_FLOAT_ERRORS = (TypeError, ValueError, OverflowError)


class _InvalidArgument(Exception):
    """Validation failure whose message is returned to the caller as JSON."""


def _error_json(message: str) -> str:
    """Serialise *message* as the tool's ``{"error": ...}`` payload."""
    return json.dumps({"error": message})


def _parse_json_object(raw: str, arg_name: str) -> JSONObject:
    """Parse *raw* as a JSON object, raising ``_InvalidArgument`` otherwise."""
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError) as exc:
        raise _InvalidArgument(f"{arg_name} is not valid JSON: {exc}") from exc
    if not isinstance(parsed, dict):
        raise _InvalidArgument(
            f"{arg_name} must be a JSON object (dict), "
            f"got {type(parsed).__name__}."
        )
    return parsed


def _load_parameter_space(raw: str) -> list[object]:
    """Parse ``parameter_space_json`` and check it is a non-empty array."""
    try:
        entries = json.loads(raw)
    except (json.JSONDecodeError, TypeError) as exc:
        raise _InvalidArgument(
            f"parameter_space_json is not valid JSON: {exc}"
        ) from exc
    if not isinstance(entries, list):
        raise _InvalidArgument("parameter_space_json must be a JSON array.")
    if not entries:
        raise _InvalidArgument(
            "parameter_space_json is an empty array. "
            "Provide at least one parameter to optimise."
        )
    return entries


def _optional_float(entry: dict, key: str) -> float | None:
    """Return ``float(entry[key])``, or ``None`` if absent or not numeric."""
    if key not in entry:
        return None
    try:
        return float(entry[key])
    except _FLOAT_ERRORS:
        return None


def _numeric_bounds(
    where: str, ptype: str, entry: dict
) -> tuple[float, float, float | None]:
    """Validate ``min``/``max``/``step`` of a numeric or integer entry.

    Returns ``(min, max, step)`` as floats (``step`` is ``None`` when absent).
    *where* is the error-message prefix identifying the entry.
    """
    if "min" not in entry or "max" not in entry:
        raise _InvalidArgument(
            f"{where}type '{ptype}' requires both 'min' and 'max' fields."
        )
    try:
        lo = float(entry["min"])
        hi = float(entry["max"])
    except _FLOAT_ERRORS:
        raise _InvalidArgument(
            f"{where}'min' and 'max' must be numbers, "
            f"got min={entry['min']!r}, max={entry['max']!r}."
        ) from None
    # json.loads accepts NaN/Infinity literals; NaN would otherwise slip past
    # the ordering check below because every comparison with NaN is False.
    if not (math.isfinite(lo) and math.isfinite(hi)):
        raise _InvalidArgument(
            f"{where}'min' and 'max' must be finite numbers, "
            f"got min={entry['min']!r}, max={entry['max']!r}."
        )
    if lo >= hi:
        raise _InvalidArgument(
            f"{where}'min' ({lo}) must be strictly less than 'max' ({hi})."
        )

    step: float | None = None
    if "step" in entry:
        try:
            step = float(entry["step"])
        except _FLOAT_ERRORS:
            raise _InvalidArgument(
                f"{where}'step' must be a number, got {entry['step']!r}."
            ) from None
        if not math.isfinite(step):
            raise _InvalidArgument(
                f"{where}'step' must be a finite number, got {entry['step']!r}."
            )
        if step <= 0:
            raise _InvalidArgument(f"{where}'step' must be positive, got {step}.")
    return lo, hi, step


def _build_parameter_specs(entries: list[object]) -> list[ParameterSpec]:
    """Validate each parameter-space entry and convert it to a ``ParameterSpec``.

    Raises ``_InvalidArgument`` describing the first invalid entry.
    """
    specs: list[ParameterSpec] = []
    seen_names: set[str] = set()
    for index, entry in enumerate(entries):
        if not isinstance(entry, dict):
            raise _InvalidArgument(
                f"parameter_space[{index}] must be an object, "
                f"got {type(entry).__name__}."
            )
        name = entry.get("name")
        if not name or not isinstance(name, str):
            raise _InvalidArgument(
                f"parameter_space[{index}] is missing a 'name' string field."
            )
        if name in seen_names:
            raise _InvalidArgument(
                f"parameter_space[{index}]: duplicate parameter name '{name}'. "
                "Each parameter must have a unique name."
            )
        seen_names.add(name)

        where = f"parameter_space[{index}] ('{name}'): "
        ptype = entry.get("type")
        if ptype not in _VALID_PARAM_TYPES:
            raise _InvalidArgument(
                f"{where}invalid type '{ptype}'. "
                f"Must be one of: {', '.join(repr(t) for t in _VALID_PARAM_TYPES)}."
            )

        if ptype == "categorical":
            choices_raw = entry.get("choices")
            if not isinstance(choices_raw, list) or not choices_raw:
                raise _InvalidArgument(
                    f"{where}type 'categorical' requires a non-empty 'choices' array."
                )
            # Range fields are not validated for categorical entries; forward
            # them when numeric (as before) but drop e.g. ``"min": null``
            # instead of letting float() raise out of the tool.
            min_value = _optional_float(entry, "min")
            max_value = _optional_float(entry, "max")
            step = _optional_float(entry, "step")
        else:
            min_value, max_value, step = _numeric_bounds(where, ptype, entry)

        choices = entry.get("choices")
        specs.append(
            ParameterSpec(
                name=name,
                param_type=ptype,
                min_value=min_value,
                max_value=max_value,
                log_scale=bool(entry.get("logScale", False)),
                step=step,
                choices=(
                    [str(c) for c in choices] if isinstance(choices, list) else None
                ),
            )
        )
    return specs


class OptimizationToolsMixin:
    """Kani tool mixin for search parameter optimization.

    ``site_id`` is forwarded to the optimizer and ``_emit_event`` is passed to
    it as the progress callback. If the host object has a ``_cancel_event``
    attribute, its ``is_set`` method is passed as the ``check_cancelled``
    callback.
    """

    site_id: str = ""

    async def _emit_event(self, event: JSONObject) -> None:
        """Emit an SSE event. Override in subclass to push to streaming queue."""

    # NOTE: the docstring and the AIParam descriptions below are part of the
    # tool schema kani presents to the model, so their wording is kept as-is;
    # maintainer documentation lives in comments instead.
    #
    # Returns a JSON string. Any argument validation failure yields
    # ``{"error": "<message>"}`` without running the optimizer. On success it
    # is the serialised optimization result, plus a ``downloads`` entry when
    # the best-effort JSON export succeeds.
    @ai_function()
    async def optimize_search_parameters(
        self,
        record_type: Annotated[
            str, AIParam(desc="WDK record type (e.g. 'transcript')")
        ],
        search_name: Annotated[
            str, AIParam(desc="WDK search/question urlSegment to optimise")
        ],
        parameter_space_json: Annotated[
            str,
            AIParam(
                desc=(
                    "JSON array of parameters to optimise. Each entry is an object: "
                    '{"name": "<paramName>", "type": "numeric"|"integer"|"categorical", '
                    '"min": <number>, "max": <number>, "logScale"?: bool, "step"?: <number>, "choices"?: ["a","b"]}. '
                    "Example: "
                    '[{"name":"fold_change","type":"numeric","min":1.5,"max":20}]'
                )
            ),
        ],
        fixed_parameters_json: Annotated[
            str,
            AIParam(
                desc=(
                    "JSON object of parameters held constant during optimisation. "
                    'Example: {"organism":"P. falciparum 3D7","direction":"up-regulated"}'
                )
            ),
        ],
        controls_search_name: Annotated[
            str,
            AIParam(desc="Search that accepts a list of record IDs (for controls)"),
        ],
        controls_param_name: Annotated[
            str,
            AIParam(desc="Parameter name within controls_search_name that accepts IDs"),
        ],
        positive_controls: Annotated[
            list[str] | None,
            AIParam(desc="Known-positive IDs that should be returned"),
        ] = None,
        negative_controls: Annotated[
            list[str] | None,
            AIParam(desc="Known-negative IDs that should NOT be returned"),
        ] = None,
        budget: Annotated[
            int, AIParam(desc="Max number of trials (default 15, max 50)")
        ] = 15,
        objective: Annotated[
            str,
            AIParam(
                desc=(
                    "Scoring objective: 'f1' (balanced, default), 'recall', "
                    "'precision', 'f_beta' (specify beta), or 'custom'"
                )
            ),
        ] = "f1",
        beta: Annotated[
            float, AIParam(desc="Beta value for f_beta objective (default 1.0)")
        ] = 1.0,
        method: Annotated[
            str,
            AIParam(
                desc="Optimisation method: 'bayesian' (default, recommended), 'grid', or 'random'"
            ),
        ] = "bayesian",
        controls_value_format: Annotated[
            ControlValueFormat,
            AIParam(desc="How to encode the control ID list"),
        ] = "newline",
        controls_extra_parameters_json: Annotated[
            str | None,
            AIParam(
                desc="JSON object of extra fixed parameters for the controls search"
            ),
        ] = None,
        id_field: Annotated[
            str | None,
            AIParam(desc="Optional record-id field name for answer records"),
        ] = None,
        result_count_penalty: Annotated[
            float,
            AIParam(
                desc=(
                    "Weight for penalising large result sets (0 = off, 0.1 = tiebreaker, "
                    "higher = strongly prefer tighter results). Default 0.1."
                )
            ),
        ] = 0.1,
    ) -> str:
        """Optimise search parameters against positive/negative control gene lists.

        Runs multiple trials, varying the parameters in `parameter_space` while
        holding `fixed_parameters` constant. Each trial evaluates the search
        against the controls and scores the result. Returns the best
        configuration, all trials, Pareto frontier, and sensitivity analysis.

        This is a long-running operation. The user will see real-time progress
        in the UI. Always confirm the plan with the user before calling this.
        """
        # -- scalar argument validation ----------------------------------------
        required_strings = (
            ("record_type", record_type),
            ("search_name", search_name),
            ("controls_search_name", controls_search_name),
            ("controls_param_name", controls_param_name),
        )
        for arg_name, value in required_strings:
            if not value or not value.strip():
                return _error_json(
                    f"{arg_name} is required and must be a non-empty string."
                )

        if not positive_controls and not negative_controls:
            return _error_json(
                "At least one of positive_controls or negative_controls must be "
                "provided with at least one ID. Without any controls the optimiser "
                "has no signal to score against."
            )

        if objective not in _VALID_OBJECTIVES:
            return _error_json(
                f"Invalid objective '{objective}'. "
                f"Must be one of: {', '.join(repr(o) for o in _VALID_OBJECTIVES)}."
            )
        if method not in _VALID_METHODS:
            return _error_json(
                f"Invalid method '{method}'. "
                f"Must be one of: {', '.join(repr(m) for m in _VALID_METHODS)}."
            )
        if controls_value_format not in _VALID_CONTROL_FORMATS:
            return _error_json(
                f"Invalid controls_value_format '{controls_value_format}'. "
                f"Must be one of: {', '.join(repr(f) for f in _VALID_CONTROL_FORMATS)}."
            )

        if not isinstance(budget, int) or budget < 1:
            return _error_json(f"budget must be a positive integer, got {budget!r}.")
        if budget > 50:
            return _error_json(
                f"budget={budget} exceeds the maximum of 50. "
                "Use a smaller budget or narrow the parameter space."
            )
        # beta is only validated when it is actually used by the objective.
        if objective == "f_beta" and (
            not isinstance(beta, (int, float)) or beta <= 0
        ):
            return _error_json(
                f"beta must be a positive number when objective is 'f_beta', got {beta!r}."
            )

        # -- JSON argument parsing & parameter_space validation ----------------
        # Order matters for which error is reported first: array shape, then
        # the two JSON objects, then the individual parameter entries.
        try:
            entries = _load_parameter_space(parameter_space_json)
            fixed_parameters: JSONObject = (
                _parse_json_object(fixed_parameters_json, "fixed_parameters_json")
                if fixed_parameters_json
                else {}
            )
            controls_extra_parameters: JSONObject | None = (
                _parse_json_object(
                    controls_extra_parameters_json, "controls_extra_parameters_json"
                )
                if controls_extra_parameters_json
                else None
            )
            specs = _build_parameter_specs(entries)
        except _InvalidArgument as exc:
            return _error_json(str(exc))

        config = OptimizationConfig(
            budget=budget,
            objective=cast(OptimizationObjective, objective),
            beta=beta,
            method=cast(OptimizationMethod, method),
            # Negative penalties are clamped to zero.
            result_count_penalty=max(0.0, result_count_penalty),
        )

        # The host class may or may not provide a cancellation event.
        cancel_event = getattr(self, "_cancel_event", None)
        result = await _run_optimization(
            site_id=self.site_id,
            record_type=record_type,
            search_name=search_name,
            fixed_parameters=fixed_parameters,
            parameter_space=specs,
            controls_search_name=controls_search_name,
            controls_param_name=controls_param_name,
            positive_controls=positive_controls,
            negative_controls=negative_controls,
            controls_value_format=controls_value_format,
            controls_extra_parameters=controls_extra_parameters,
            id_field=id_field,
            config=config,
            progress_callback=self._emit_event,
            check_cancelled=(cancel_event.is_set if cancel_event is not None else None),
        )
        result_json = _opt_result_to_json(result)

        # Auto-generate exports. Best-effort: a failed export must not discard
        # the optimization result, but it is logged rather than silently dropped.
        try:
            from veupath_chatbot.services.export import get_export_service

            svc = get_export_service()
            json_export = await svc.export_json(
                result_json, f"{search_name}_optimization"
            )
            result_json["downloads"] = {
                "json": json_export.url,
                "expiresInSeconds": json_export.expires_in_seconds,
            }
        except Exception:
            _logger.warning(
                "Could not export optimization result for search %r; "
                "returning result without download links.",
                search_name,
                exc_info=True,
            )
        return json.dumps(result_json)