"""Multi-step tree-knob optimization.
Tunes threshold parameters and boolean operators across a strategy tree
using Optuna, optimizing for rank-based objectives (Precision@K,
Enrichment@K) with optional list-size constraints.
"""
import copy
import time
from veupath_chatbot.domain.strategy.tree import walk_dict_tree
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.helpers import safe_int
from veupath_chatbot.services.experiment.metrics import (
compute_confusion_matrix,
compute_metrics,
)
from veupath_chatbot.services.experiment.types import (
ControlValueFormat,
ExperimentMetrics,
OperatorKnob,
ThresholdKnob,
TreeOptimizationResult,
TreeOptimizationTrial,
)
logger = get_logger(__name__)
[docs]
async def optimize_tree_knobs(
    *,
    site_id: str,
    record_type: str,
    base_tree: JSONObject,
    threshold_knobs: list[ThresholdKnob],
    operator_knobs: list[OperatorKnob],
    positive_controls: list[str],
    negative_controls: list[str],
    controls_search_name: str,
    controls_param_name: str,
    controls_value_format: ControlValueFormat,
    objective: str = "precision_at_50",
    budget: int = 50,
    max_list_size: int | None = None,
) -> TreeOptimizationResult:
    """Search for the best knob settings on a strategy tree with Optuna.

    Runs ``budget`` sequential ask/tell trials on a maximizing study that
    uses ``TPESampler(seed=42)``. Each trial samples one float per threshold
    knob and one categorical value per operator knob, applies them to a deep
    copy of ``base_tree``, evaluates the copy with
    ``run_controls_against_tree`` and scores it via ``_select_metric``.

    A trial whose result count exceeds ``max_list_size`` is reported to the
    study with a score of ``-1.0`` and is *not* added to the returned history.

    If Optuna cannot be imported, an error is logged and a result carrying
    only ``objective`` is returned.

    :param site_id: Forwarded to ``run_controls_against_tree``.
    :param record_type: Forwarded to ``run_controls_against_tree``.
    :param base_tree: ``PlanStepNode``-shaped dict (the template tree); it is
        deep-copied for every trial and never mutated.
    :param threshold_knobs: Numeric parameter knobs on leaf steps.
    :param operator_knobs: Boolean operator knobs on combine nodes.
    :param positive_controls: Positive control IDs; the list length is used
        as the total-positives count.
    :param negative_controls: Negative control IDs; the list length is used
        as the total-negatives count.
    :param controls_search_name: Forwarded to ``run_controls_against_tree``.
    :param controls_param_name: Forwarded to ``run_controls_against_tree``.
    :param controls_value_format: Forwarded to ``run_controls_against_tree``.
    :param objective: Target metric name (e.g. ``precision_at_50``).
    :param budget: Number of Optuna trials to run.
    :param max_list_size: Optional upper bound on result list size.
    :returns: Optimization result with the best recorded trial (``None`` when
        no trial was recorded), the trial history, elapsed seconds and the
        objective name.
    """
    try:
        import optuna

        optuna.logging.set_verbosity(optuna.logging.WARNING)
    except ImportError:
        logger.error("Optuna not installed — tree optimization unavailable")
        return TreeOptimizationResult(objective=objective)

    # Imported at call time, as in the original layout of this module.
    from veupath_chatbot.services.experiment.step_analysis import (
        run_controls_against_tree,
    )

    started_at = time.monotonic()
    history: list[TreeOptimizationTrial] = []
    best_trial: TreeOptimizationTrial | None = None

    def _int_field(section: object, field: str) -> int:
        """Read ``field`` from a result section as an int; 0 if not a dict."""
        if not isinstance(section, dict):
            return 0
        return safe_int(section.get(field, 0))

    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=42),
    )

    async def _run_trial(trial: optuna.Trial) -> float:
        """Sample knob values, evaluate the resulting tree, return its score."""
        nonlocal best_trial

        # Thresholds are suggested first, operators second.
        threshold_vals: dict[str, float] = {}
        for knob in threshold_knobs:
            knob_key = f"{knob.step_id}:{knob.param_name}"
            threshold_vals[knob_key] = trial.suggest_float(
                knob_key,
                knob.min_val,
                knob.max_val,
                step=knob.step_size or None,  # falsy step size -> continuous
            )

        operator_vals: dict[str, str] = {
            op_knob.combine_node_id: str(
                trial.suggest_categorical(
                    op_knob.combine_node_id,
                    op_knob.options,
                )
            )
            for op_knob in operator_knobs
        }

        # Work on a private copy so the template tree stays untouched.
        candidate_tree = copy.deepcopy(base_tree)
        _apply_knobs_recursive(candidate_tree, threshold_vals, operator_vals)

        result = await run_controls_against_tree(
            site_id=site_id,
            record_type=record_type,
            tree=candidate_tree,
            controls_search_name=controls_search_name,
            controls_param_name=controls_param_name,
            controls_value_format=controls_value_format,
            positive_controls=positive_controls,
            negative_controls=negative_controls,
        )

        total_results = _int_field(result.get("target", {}), "resultCount")
        if max_list_size is not None and total_results > max_list_size:
            # Constraint violated: penalize and skip recording this trial.
            return -1.0

        positive_hits = _int_field(result.get("positive", {}), "intersectionCount")
        negative_hits = _int_field(result.get("negative", {}), "intersectionCount")
        positives_total = len(positive_controls)

        confusion = compute_confusion_matrix(
            positive_hits=positive_hits,
            total_positives=positives_total,
            negative_hits=negative_hits,
            total_negatives=len(negative_controls),
        )
        metrics = compute_metrics(confusion, total_results=total_results)

        # Enrichment is precision relative to positives-per-result; both
        # divisions fall back to 0.0 when the denominator is not positive.
        baseline = positives_total / total_results if total_results > 0 else 0.0
        enrichment = metrics.precision / baseline if baseline > 0 else 0.0

        score = _select_metric(objective, metrics=metrics, enrichment=enrichment)

        sampled: dict[str, float | str] = {**threshold_vals, **operator_vals}
        record = TreeOptimizationTrial(
            trial_number=trial.number + 1,  # reported 1-based
            parameters=sampled,
            score=score,
            list_size=total_results,
        )
        history.append(record)
        # Strict ">" keeps the earliest trial on ties.
        if best_trial is None or score > best_trial.score:
            best_trial = record
        return score

    for _ in range(budget):
        current = study.ask()
        outcome = await _run_trial(current)
        study.tell(current, outcome)

    return TreeOptimizationResult(
        best_trial=best_trial,
        all_trials=history,
        total_time_seconds=time.monotonic() - started_at,
        objective=objective,
    )
def _apply_knobs_recursive(
    node: JSONObject,
    threshold_vals: dict[str, float],
    operator_vals: dict[str, str],
) -> None:
    """Write knob values into the tree rooted at ``node``, mutating it.

    The per-node callback below is handed to ``walk_dict_tree``, which
    decides which nodes are visited.

    :param node: Root of the tree to modify in place.
    :param threshold_vals: Values keyed by ``"<step_id>:<param_name>"``; each
        is stored as ``str(value)`` in the ``parameters`` dict of the node
        whose ``id`` equals ``step_id``.
    :param operator_vals: Operator values keyed by node ``id``; a match
        overwrites the node's ``operator`` entry.
    """

    def _visit(current: JSONObject) -> None:
        node_id = str(current.get("id", ""))

        if node_id in operator_vals:
            current["operator"] = operator_vals[node_id]

        params = current.get("parameters")
        if not isinstance(params, dict):
            # Nothing to write thresholds into on this node.
            return

        for knob_key, value in threshold_vals.items():
            # Split on the first ":" only; the remainder is the param name.
            owner_id, param_name = knob_key.split(":", 1)
            if owner_id == node_id:
                params[param_name] = str(value)

    walk_dict_tree(node, _visit)
def _select_metric(
    objective: str,
    *,
    metrics: ExperimentMetrics,
    enrichment: float,
) -> float:
    """Pick the score that corresponds to ``objective``.

    Matching is case-insensitive and evaluated in this priority order:

    1. contains ``precision`` -> ``metrics.precision``
    2. equals ``sensitivity`` or starts with ``recall`` -> ``metrics.sensitivity``
    3. contains ``enrichment`` -> the ``enrichment`` argument
    4. contains ``specificity`` -> ``metrics.specificity``
    5. contains ``balanced_accuracy`` -> ``metrics.balanced_accuracy``
    6. contains ``mcc`` -> ``metrics.mcc``
    7. contains ``f1`` -> ``metrics.f1_score``

    Anything else falls back to ``metrics.precision``.

    :param objective: Objective name such as ``precision_at_50`` or ``f1``.
    :param metrics: Metrics object the selected attribute is read from.
    :param enrichment: Precomputed enrichment value.
    :returns: The selected metric value.
    """
    wanted = objective.lower()

    if "precision" in wanted:
        return metrics.precision
    # "recall" on its own is covered by the prefix test.
    if wanted == "sensitivity" or wanted.startswith("recall"):
        return metrics.sensitivity
    if "enrichment" in wanted:
        return enrichment

    # Remaining metrics are plain substring matches, checked in order; the
    # attribute is only read once its token has matched.
    substring_rules = (
        ("specificity", "specificity"),
        ("balanced_accuracy", "balanced_accuracy"),
        ("mcc", "mcc"),
        ("f1", "f1_score"),
    )
    for token, attribute in substring_rules:
        if token in wanted:
            return getattr(metrics, attribute)

    return metrics.precision