"""Trial execution loop for parameter optimization."""
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
import optuna
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject, JSONValue
from veupath_chatbot.services.control_tests import run_positive_negative_controls
from veupath_chatbot.services.experiment.types import ControlValueFormat
from veupath_chatbot.services.parameter_optimization.callbacks import (
emit_error,
emit_trial_progress,
)
from veupath_chatbot.services.parameter_optimization.config import (
CancelCheck,
OptimizationConfig,
OptimizationResult,
ParameterSpec,
ProgressCallback,
TrialResult,
)
from veupath_chatbot.services.parameter_optimization.scoring import (
_compute_pareto_frontier,
_compute_score,
_compute_sensitivity,
_to_float,
_to_int,
_trial_to_json,
)
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Abort the run once this many trials in a row have failed (only applied
# while no trial has succeeded yet).
_MAX_CONSECUTIVE_FAILURES = 5
# Default number of trials without improvement that counts as a plateau.
_PLATEAU_WINDOW = 10
# Default score at or above which the best trial is treated as perfect.
_PERFECT_SCORE_THRESHOLD = 0.9999
# Size of each trial batch and of the semaphore guarding evaluations.
_PARALLEL_CONCURRENCY = 4
# Digits kept when float parameters are normalised into evaluation-cache keys.
_CACHE_PRECISION = 5
# ---------------------------------------------------------------------------
# Sampler creation
# ---------------------------------------------------------------------------
def _create_sampler(
cfg: OptimizationConfig,
parameter_space: list[ParameterSpec],
budget: int,
) -> tuple[optuna.samplers.BaseSampler, int]:
"""Create an Optuna sampler and (possibly adjusted) budget.
Grid search may reduce *budget* if the total number of combinations
is smaller than the requested budget.
"""
match cfg.method:
case "bayesian":
return optuna.samplers.TPESampler(seed=42), budget
case "grid":
grid: dict[str, list[float | int | str]] = {}
for p in parameter_space:
if p.param_type == "categorical" and p.choices:
grid[p.name] = list(p.choices)
elif p.param_type == "integer":
lo, hi = int(p.min_value or 0), int(p.max_value or 10)
st = int(p.step) if p.step else max(1, (hi - lo) // 10)
grid[p.name] = list(range(lo, hi + 1, st))
else: # numeric
lo_f, hi_f = p.min_value or 0.0, p.max_value or 1.0
n_levels = min(10, budget)
step_size = (hi_f - lo_f) / max(n_levels - 1, 1)
grid[p.name] = [lo_f + i * step_size for i in range(n_levels)]
total_combos: int = 1
for v in grid.values():
total_combos *= len(v)
return optuna.samplers.GridSampler(grid), int(min(total_combos, budget))
case "random":
return optuna.samplers.RandomSampler(seed=42), budget
case _:
return optuna.samplers.TPESampler(seed=42), budget
# ---------------------------------------------------------------------------
# Trial parameter suggestion
# ---------------------------------------------------------------------------
def _suggest_trial_params(
    trial: optuna.trial.Trial,
    parameter_space: list[ParameterSpec],
) -> dict[str, JSONValue]:
    """Suggest parameter values for one Optuna trial.

    Each spec is sampled according to its ``param_type`` (``"numeric"``,
    ``"integer"`` or ``"categorical"``); specs of any other type are skipped.
    Bounds that are ``None`` fall back to defaults (numeric: 0.0..1.0,
    integer: 0..100).  An explicit bound of ``0`` is honoured rather than
    being mistaken for "unset".

    Returns:
        Mapping of parameter name to the value suggested for this trial.
    """
    params: dict[str, JSONValue] = {}
    for spec in parameter_space:
        if spec.param_type == "numeric":
            # Compare against None: ``value or default`` would replace a
            # legitimate upper bound of 0 with the default.
            params[spec.name] = trial.suggest_float(
                spec.name,
                spec.min_value if spec.min_value is not None else 0.0,
                spec.max_value if spec.max_value is not None else 1.0,
                log=spec.log_scale,
            )
        elif spec.param_type == "integer":
            params[spec.name] = trial.suggest_int(
                spec.name,
                int(spec.min_value) if spec.min_value is not None else 0,
                int(spec.max_value) if spec.max_value is not None else 100,
                step=int(spec.step) if spec.step else 1,
            )
        elif spec.param_type == "categorical":
            # Optuna needs at least one choice; fall back to a single empty
            # string when none were supplied.
            params[spec.name] = trial.suggest_categorical(
                spec.name,
                spec.choices or [""],
            )
    return params
# ---------------------------------------------------------------------------
# Trial context (shared immutable config + mutable accumulated state)
# ---------------------------------------------------------------------------
@dataclass
class _TrialContext:
    """Configuration and accumulated state shared across the trial loop.

    Most fields are set once by the caller and only read by the loop; the
    ``trials`` list and ``best_trial`` are updated as trials are processed.
    """

    # Target search: forwarded to the positive/negative controls evaluation.
    site_id: str
    record_type: str
    search_name: str
    # Parameters held constant; non-empty entries are merged beneath the
    # sampled values for every trial.
    fixed_parameters: dict[str, JSONValue]
    # Specs describing the parameters that are sampled per trial.
    parameter_space: list[ParameterSpec]
    # Controls configuration: forwarded unchanged to the controls evaluation.
    controls_search_name: str
    controls_param_name: str
    positive_controls: list[str] | None
    negative_controls: list[str] | None
    controls_value_format: ControlValueFormat
    controls_extra_parameters: JSONObject | None
    id_field: str | None
    # Scoring configuration passed to the score computation.
    cfg: OptimizationConfig
    # Identifier attached to log records and progress events.
    optimization_id: str
    # Maximum number of trials the loop may run.
    budget: int
    # Optuna study used to ask for and report trials.
    study: optuna.Study
    # Optional hooks: progress reporting and cooperative cancellation.
    progress_callback: ProgressCallback | None
    check_cancelled: CancelCheck | None
    # Reference instant; elapsed time is measured against time.monotonic().
    start_time: float
    # Accumulated state
    # Every processed trial (failed or successful), in processing order.
    trials: list[TrialResult]
    # Highest-scoring successful trial so far, if any.
    best_trial: TrialResult | None = None
# ---------------------------------------------------------------------------
# WDK metrics extraction
# ---------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class TrialMetrics:
    """Intermediate metrics extracted from a WDK result.

    Every field is ``None`` when the corresponding value could not be
    obtained from the result.
    """

    # Read from the "positive" section's "recall" entry.
    recall: float | None
    # Read from the "negative" section's "falsePositiveRate" entry.
    fpr: float | None
    # Read from the "target" section's "resultCount" entry.
    result_count: int | None
    # Read from the "positive" section's "intersectionCount" entry.
    positive_hits: int | None
    # Read from the "negative" section's "intersectionCount" entry.
    negative_hits: int | None
def _extract_trial_metrics(wdk_result: JSONObject) -> TrialMetrics:
    """Pull recall, FPR, result count and hit counts out of a WDK result.

    Values are looked up in the ``"target"``, ``"positive"`` and
    ``"negative"`` sections; a section that is not a dict contributes
    ``None`` for each of its fields.
    """

    def _pick(section_name: str, field_name: str) -> object:
        # Only dict-shaped sections can be queried; anything else yields None.
        section = wdk_result.get(section_name)
        if not isinstance(section, dict):
            return None
        return section.get(field_name)

    return TrialMetrics(
        recall=_to_float(_pick("positive", "recall")),
        fpr=_to_float(_pick("negative", "falsePositiveRate")),
        result_count=_to_int(_pick("target", "resultCount")),
        positive_hits=_to_int(_pick("positive", "intersectionCount")),
        negative_hits=_to_int(_pick("negative", "intersectionCount")),
    )
# ---------------------------------------------------------------------------
# Trial builders
# ---------------------------------------------------------------------------
def _build_failed_trial(
    *,
    trial_number: int,
    params: dict[str, JSONValue],
    n_positives: int,
    n_negatives: int,
) -> TrialResult:
    """Build the TrialResult recorded when a trial fails.

    A failed trial (WDK error or exception) scores ``0.0`` and carries no
    recall, false-positive rate or result count; only the control totals
    are filled in.
    """
    failed = TrialResult(
        trial_number=trial_number,
        parameters=params,
        score=0.0,
        recall=None,
        false_positive_rate=None,
        result_count=None,
        total_positives=n_positives,
        total_negatives=n_negatives,
    )
    return failed
def _build_successful_trial(
    *,
    trial_number: int,
    params: dict[str, JSONValue],
    wdk_result: JSONObject,
    cfg: OptimizationConfig,
    n_positives: int,
    n_negatives: int,
) -> TrialResult:
    """Turn a successful WDK evaluation into a scored TrialResult.

    The metrics are extracted from *wdk_result*, scored with *cfg*, and
    stored alongside the control totals.
    """
    measured = _extract_trial_metrics(wdk_result)
    recall, fpr = measured.recall, measured.fpr
    trial_score = _compute_score(
        recall,
        fpr,
        cfg,
        result_count=measured.result_count,
        positive_hits=measured.positive_hits,
        negative_hits=measured.negative_hits,
    )
    return TrialResult(
        trial_number=trial_number,
        parameters=params,
        score=trial_score,
        recall=recall,
        false_positive_rate=fpr,
        result_count=measured.result_count,
        positive_hits=measured.positive_hits,
        negative_hits=measured.negative_hits,
        total_positives=n_positives,
        total_negatives=n_negatives,
    )
# ---------------------------------------------------------------------------
# Result aggregation
# ---------------------------------------------------------------------------
def _aggregate_results(
    ctx: _TrialContext,
    status: str,
    error_message: str | None = None,
) -> OptimizationResult:
    """Assemble the final OptimizationResult from the trials gathered so far.

    The elapsed time is measured from ``ctx.start_time``; the Pareto frontier
    and parameter sensitivity are computed from the accumulated trials and
    the Optuna study.
    """
    # Take the timestamp first so the summary computations are not counted.
    seconds_elapsed = time.monotonic() - ctx.start_time
    frontier = _compute_pareto_frontier(ctx.trials)
    sensitivity = _compute_sensitivity(ctx.parameter_space, ctx.study)
    return OptimizationResult(
        optimization_id=ctx.optimization_id,
        best_trial=ctx.best_trial,
        all_trials=ctx.trials,
        pareto_frontier=frontier,
        sensitivity=sensitivity,
        total_time_seconds=seconds_elapsed,
        status=status,
        error_message=error_message,
    )
# ---------------------------------------------------------------------------
# Progress emission
# ---------------------------------------------------------------------------
async def _emit_trial_result(
    ctx: _TrialContext,
    *,
    trial_num: int,
    trial_result: TrialResult,
    wdk_error: str = "",
) -> None:
    """Send a progress event describing one finished trial.

    Does nothing when no progress callback is configured.  When *wdk_error*
    is non-empty it is attached to the trial payload under ``"error"``.
    """
    callback = ctx.progress_callback
    if callback:
        payload = _trial_to_json(trial_result)
        if wdk_error:
            # Copy rather than mutate the serialised trial.
            payload = {**payload, "error": wdk_error}
        await emit_trial_progress(
            callback,
            optimization_id=ctx.optimization_id,
            trial_num=trial_num,
            budget=ctx.budget,
            trial_json=payload,
            best_trial=ctx.best_trial,
            recent_trials=ctx.trials[-5:],
        )
# ---------------------------------------------------------------------------
# WDK evaluation (cached, semaphore-guarded)
# ---------------------------------------------------------------------------
def _cache_key(params: dict[str, JSONValue]) -> tuple[tuple[str, str], ...]:
    """Build a hashable key from optimised params (normalised floats).

    Parameters are visited in sorted-name order, so the key does not depend
    on dict insertion order.  Non-float values are stringified as-is.
    Floats are normalised so that nearly identical values share a key:

    * magnitudes of 1 or more (and exact zero) are rounded to
      ``_CACHE_PRECISION`` decimal places;
    * non-zero magnitudes below 1 keep ``_CACHE_PRECISION`` significant
      digits instead.

    Returns:
        A tuple of ``(name, normalised value text)`` pairs.
    """
    items: list[tuple[str, str]] = []
    for k in sorted(params):
        v = params[k]
        if not isinstance(v, float):
            items.append((k, str(v)))
        elif v != 0.0 and abs(v) < 1.0:
            # Fixed-decimal rounding would collapse distinct small values
            # (e.g. 1e-10 and 1e-30 sampled on a log scale) onto "0.0" and
            # make the cache return a result for different parameters.
            items.append((k, format(v, f".{_CACHE_PRECISION}g")))
        else:
            items.append((k, str(round(v, _CACHE_PRECISION))))
    return tuple(items)
# Memo of evaluation outcomes: cache key -> (WDK result or None, error text).
_EvalCache = dict[tuple[tuple[str, str], ...], tuple[JSONObject | None, str]]
async def _evaluate_trial(
    ctx: _TrialContext,
    trial_params: JSONObject,
    optimised_params: dict[str, JSONValue],
    sem: asyncio.Semaphore,
    cache: _EvalCache,
) -> tuple[JSONObject | None, str]:
    """Evaluate one trial against WDK, memoised and concurrency-limited.

    The cache is keyed on *optimised_params* only; *trial_params* is the
    full parameter set sent to WDK.  Both successful and failed outcomes
    are stored in *cache*.

    Returns (wdk_result_or_None, error_string).
    """
    memo_key = _cache_key(optimised_params)
    if (hit := cache.get(memo_key)) is not None:
        logger.debug("Cache hit for params", key=memo_key)
        return hit

    async with sem:
        outcome: JSONObject | None
        failure = ""
        try:
            outcome = await run_positive_negative_controls(
                site_id=ctx.site_id,
                record_type=ctx.record_type,
                target_search_name=ctx.search_name,
                target_parameters=trial_params,
                controls_search_name=ctx.controls_search_name,
                controls_param_name=ctx.controls_param_name,
                positive_controls=ctx.positive_controls,
                negative_controls=ctx.negative_controls,
                controls_value_format=ctx.controls_value_format,
                controls_extra_parameters=ctx.controls_extra_parameters,
                id_field=ctx.id_field,
            )
        except Exception as trial_exc:
            # Any evaluation exception becomes an error string for the caller.
            outcome, failure = None, str(trial_exc)
        else:
            # A payload whose "error" entry is a string is a reported failure.
            if outcome is not None and isinstance(outcome.get("error"), str):
                outcome, failure = None, str(outcome["error"])
        evaluation = (outcome, failure)
        cache[memo_key] = evaluation
        return evaluation
# ---------------------------------------------------------------------------
# Early-stop checks
# ---------------------------------------------------------------------------
class EarlyStopReason(Enum):
    """Why the optimisation loop stopped early."""

    # The best trial's score reached the perfect-score threshold.
    PERFECT_SCORE = "perfect_score"
    # Too many trials passed without improving on the best score.
    PLATEAU = "plateau"
def _check_early_stop(
    *,
    best_trial: TrialResult | None,
    trials_since_improvement: int,
    plateau_window: int = _PLATEAU_WINDOW,
    perfect_score_threshold: float = _PERFECT_SCORE_THRESHOLD,
) -> EarlyStopReason | None:
    """Decide, without side effects or logging, whether to stop early.

    A best score at or above *perfect_score_threshold* takes precedence over
    a plateau of *plateau_window* or more trials without improvement.

    Returns the reason for stopping, or None to continue.
    """
    verdict: EarlyStopReason | None = None
    if best_trial and best_trial.score >= perfect_score_threshold:
        verdict = EarlyStopReason.PERFECT_SCORE
    elif trials_since_improvement >= plateau_window:
        verdict = EarlyStopReason.PLATEAU
    return verdict
def _should_early_stop(
    ctx: _TrialContext,
    trials_since_improvement: int,
    trial_num: int,
) -> bool:
    """Return True when the loop should stop early, logging the reason."""
    reason = _check_early_stop(
        best_trial=ctx.best_trial,
        trials_since_improvement=trials_since_improvement,
    )
    if reason is None:
        return False

    # Both log records report the best score, or 0 when there is no best trial.
    top_score = ctx.best_trial.score if ctx.best_trial else 0
    if reason is EarlyStopReason.PERFECT_SCORE:
        logger.info(
            "Early stop: perfect score reached",
            optimization_id=ctx.optimization_id,
            score=top_score,
            trial=trial_num,
        )
    elif reason is EarlyStopReason.PLATEAU:
        logger.info(
            "Early stop: score plateau detected",
            optimization_id=ctx.optimization_id,
            best_score=top_score,
            trials_without_improvement=trials_since_improvement,
            trial=trial_num,
        )
    return True
def _should_abort_on_failures(
    ctx: _TrialContext,
    consecutive_failures: int,
    wdk_error: str,
) -> str | None:
    """Decide whether repeated failures should abort the loop.

    The loop is aborted only when the failure streak has reached
    ``_MAX_CONSECUTIVE_FAILURES`` and no trial has succeeded so far.

    Returns the abort message, or None if the loop should continue.
    """
    streak_too_short = consecutive_failures < _MAX_CONSECUTIVE_FAILURES
    if streak_too_short or ctx.best_trial is not None:
        return None

    logger.error(
        "Aborting optimisation: all trials failed",
        optimization_id=ctx.optimization_id,
        consecutive_failures=consecutive_failures,
        last_error=wdk_error,
    )
    return (
        f"Aborted after {consecutive_failures} consecutive failures. "
        f"Last error: {wdk_error}"
    )
# ---------------------------------------------------------------------------
# Single-trial processing
# ---------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class _TrialOutcome:
    """Result of processing a single trial within the batch."""

    # The TrialResult built for this trial (failed or successful).
    trial_result: TrialResult
    # Error text from the evaluation; empty when there was none.
    wdk_error: str
    # True when reporting to Optuna signalled that the grid is exhausted.
    grid_exhausted: bool
    # True when the evaluation produced no usable WDK result.
    is_failure: bool
def _unpack_gather_result(
    raw_result: tuple[JSONObject | None, str] | BaseException,
    trial_num: int,
    params: dict[str, JSONValue],
) -> tuple[JSONObject | None, str]:
    """Normalise one ``asyncio.gather`` entry into ``(result, error)``.

    An exception entry becomes ``(None, str(exception))``.  A warning is
    logged for every exception and for every pair with a non-empty error.
    """
    if isinstance(raw_result, BaseException):
        payload, failure = None, str(raw_result)
        # Exceptions are always reported, even if their message is empty.
        report = True
    else:
        payload, failure = raw_result
        report = bool(failure)

    if report:
        logger.warning(
            "WDK evaluation failed for trial",
            trial=trial_num,
            params=params,
            error=failure,
        )
    return payload, failure
def _tell_study(
    ctx: _TrialContext,
    ot: optuna.trial.Trial,
    trial_num: int,
    *,
    score: float | None,
) -> bool:
    """Report one trial to Optuna; return True if the grid is exhausted.

    A *score* of ``None`` reports the trial as failed.  A ``RuntimeError``
    whose message mentions "stop" is how an exhausted grid sampler shows up
    when trials are driven through ask/tell; any other ``RuntimeError`` is
    re-raised.
    """
    try:
        if score is None:
            ctx.study.tell(ot, state=optuna.trial.TrialState.FAIL)
        else:
            ctx.study.tell(ot, score)
    except RuntimeError as tell_exc:
        if "stop" not in str(tell_exc).lower():
            raise
        logger.info(
            "Grid sampler exhausted search space",
            optimization_id=ctx.optimization_id,
            trial=trial_num,
        )
        return True
    return False


def _process_single_trial(
    *,
    ctx: _TrialContext,
    ot: optuna.trial.Trial,
    params: dict[str, JSONValue],
    wdk_result: JSONObject | None,
    wdk_error: str,
    trial_num: int,
    n_positives: int,
    n_negatives: int,
) -> _TrialOutcome:
    """Process a single trial result: build TrialResult and report to Optuna.

    A missing *wdk_result* is reported to Optuna as a failed trial; otherwise
    the trial is scored and the score is reported.  Grid exhaustion is
    detected on both paths, so a failing final grid trial no longer lets the
    sampler's ``RuntimeError`` escape and fail the whole run.

    Returns:
        The outcome, including the built TrialResult and whether the grid
        sampler reported exhaustion.
    """
    if wdk_result is None:
        grid_exhausted = _tell_study(ctx, ot, trial_num, score=None)
        return _TrialOutcome(
            trial_result=_build_failed_trial(
                trial_number=trial_num,
                params=params,
                n_positives=n_positives,
                n_negatives=n_negatives,
            ),
            wdk_error=wdk_error,
            grid_exhausted=grid_exhausted,
            is_failure=True,
        )

    trial_result = _build_successful_trial(
        trial_number=trial_num,
        params=params,
        wdk_result=wdk_result,
        cfg=ctx.cfg,
        n_positives=n_positives,
        n_negatives=n_negatives,
    )
    grid_exhausted = _tell_study(ctx, ot, trial_num, score=trial_result.score)
    return _TrialOutcome(
        trial_result=trial_result,
        wdk_error=wdk_error,
        grid_exhausted=grid_exhausted,
        is_failure=False,
    )
# ---------------------------------------------------------------------------
# Batch processing
# ---------------------------------------------------------------------------
@dataclass
class _LoopState:
    """Mutable counters for the trial loop."""

    # Failed trials in a row; reset to 0 by every successful trial.
    consecutive_failures: int = 0
    # Successful trials since the best score last improved.
    trials_since_improvement: int = 0
    # Set when the loop aborts on failures; returned as the final result.
    abort_result: OptimizationResult | None = None
async def _handle_failed_outcome(
    ctx: _TrialContext,
    state: _LoopState,
    outcome: _TrialOutcome,
    trial_num: int,
) -> bool:
    """Record a failed trial and decide whether the loop must abort.

    The failure streak is extended and the trial is emitted as progress.
    If the abort condition is met, an error event is emitted (when a
    callback exists) and the error result is stored on *state*.

    Returns True if the loop should abort.
    """
    state.consecutive_failures += 1
    await _emit_trial_result(
        ctx,
        trial_num=trial_num,
        trial_result=outcome.trial_result,
        wdk_error=outcome.wdk_error,
    )

    abort_msg = _should_abort_on_failures(
        ctx,
        state.consecutive_failures,
        outcome.wdk_error,
    )
    if not abort_msg:
        return False

    if ctx.progress_callback:
        await emit_error(
            ctx.progress_callback,
            optimization_id=ctx.optimization_id,
            error=abort_msg,
        )
    state.abort_result = _aggregate_results(ctx, "error", abort_msg)
    return True
async def _handle_successful_outcome(
    ctx: _TrialContext,
    state: _LoopState,
    outcome: _TrialOutcome,
    trial_num: int,
) -> bool:
    """Record a successful trial and decide whether the loop should stop.

    The failure streak is cleared, the best trial is replaced when the new
    score is strictly higher, and the trial is emitted as progress.

    Returns True if the grid is exhausted or an early-stop condition holds.
    """
    state.consecutive_failures = 0

    candidate = outcome.trial_result
    incumbent = ctx.best_trial
    if incumbent is None or candidate.score > incumbent.score:
        ctx.best_trial = candidate
        state.trials_since_improvement = 0
    else:
        state.trials_since_improvement += 1

    await _emit_trial_result(
        ctx,
        trial_num=trial_num,
        trial_result=candidate,
    )

    # Grid exhaustion wins outright; the early-stop check is then skipped.
    if outcome.grid_exhausted:
        return True
    return _should_early_stop(
        ctx,
        state.trials_since_improvement,
        trial_num,
    )
async def _process_batch(
    ctx: _TrialContext,
    state: _LoopState,
    optuna_trials: list[optuna.trial.Trial],
    batch_params: list[dict[str, JSONValue]],
    wdk_results: list[tuple[JSONObject | None, str] | BaseException],
    trial_idx: int,
    n_positives: int,
    n_negatives: int,
) -> bool:
    """Process a batch of gathered results in order.

    Each result is unpacked, turned into a trial outcome, appended to
    ``ctx.trials`` and passed to the failure or success handler.  Processing
    stops at the first handler that asks for it; later entries of the batch
    are then left unprocessed.

    Returns True if the loop should stop.
    """
    for position, gathered in enumerate(wdk_results):
        # Trial numbers are 1-based and continue across batches.
        trial_num = trial_idx + position + 1
        suggested = batch_params[position]

        payload, failure = _unpack_gather_result(gathered, trial_num, suggested)
        outcome = _process_single_trial(
            ctx=ctx,
            ot=optuna_trials[position],
            params=suggested,
            wdk_result=payload,
            wdk_error=failure,
            trial_num=trial_num,
            n_positives=n_positives,
            n_negatives=n_negatives,
        )
        ctx.trials.append(outcome.trial_result)

        handler = (
            _handle_failed_outcome
            if outcome.is_failure
            else _handle_successful_outcome
        )
        if await handler(ctx, state, outcome, trial_num):
            return True
    return False
# ---------------------------------------------------------------------------
# Main trial loop
# ---------------------------------------------------------------------------
async def run_trial_loop(ctx: _TrialContext) -> OptimizationResult:
    """Run trials in batches until the budget, a stop condition or an error.

    Each iteration checks for cancellation, asks Optuna for up to
    ``_PARALLEL_CONCURRENCY`` trials, evaluates them concurrently and
    processes the results.  An exception raised inside the loop is logged,
    reported through the progress callback (if any) and returned as an
    ``"error"`` result; otherwise the status is ``"cancelled"`` or
    ``"completed"``.
    """
    n_positives = len(ctx.positive_controls or [])
    n_negatives = len(ctx.negative_controls or [])
    state = _LoopState()
    sem = asyncio.Semaphore(_PARALLEL_CONCURRENCY)
    eval_cache: _EvalCache = {}
    # Fixed parameters that are empty strings or None are not sent at all.
    clean_fixed = {
        name: value
        for name, value in ctx.fixed_parameters.items()
        if value not in ("", None)
    }

    try:
        done = 0
        while done < ctx.budget:
            if ctx.check_cancelled and ctx.check_cancelled():
                logger.info(
                    "Optimization cancelled by user",
                    optimization_id=ctx.optimization_id,
                    completed_trials=len(ctx.trials),
                )
                break

            batch_size = min(_PARALLEL_CONCURRENCY, ctx.budget - done)
            asked = [ctx.study.ask() for _ in range(batch_size)]
            suggested = [
                _suggest_trial_params(trial, ctx.parameter_space) for trial in asked
            ]
            # Sampled values override fixed ones; the cache is keyed on the
            # sampled values only.
            pending = [
                _evaluate_trial(
                    ctx, {**clean_fixed, **chosen}, chosen, sem, eval_cache
                )
                for chosen in suggested
            ]
            gathered = await asyncio.gather(*pending, return_exceptions=True)

            stop_requested = await _process_batch(
                ctx,
                state,
                asked,
                suggested,
                gathered,
                done,
                n_positives,
                n_negatives,
            )
            if stop_requested:
                if state.abort_result:
                    return state.abort_result
                break
            done += batch_size
    except Exception as exc:
        failure = str(exc)
        logger.error("Optimization failed", error=failure, exc_info=True)
        if ctx.progress_callback:
            await emit_error(
                ctx.progress_callback,
                optimization_id=ctx.optimization_id,
                error=failure,
            )
        return _aggregate_results(ctx, "error", failure)

    status = "completed"
    if ctx.check_cancelled and ctx.check_cancelled():
        status = "cancelled"
    return _aggregate_results(ctx, status)