"""Phase 4: Parameter sensitivity -- sweep numeric params across their range."""
import asyncio
import copy
from typing import TypedDict
from veupath_chatbot.domain.strategy.tree import collect_dict_leaves, walk_dict_tree
from veupath_chatbot.integrations.veupathdb.factory import get_strategy_api
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.helpers import ProgressCallback, safe_float
from veupath_chatbot.services.experiment.step_analysis._evaluation import (
_extract_eval_counts,
_f1_from_counts,
run_controls_against_tree,
)
from veupath_chatbot.services.experiment.step_analysis._tree_utils import (
_node_id,
)
from veupath_chatbot.services.experiment.types import (
ControlValueFormat,
ParameterSensitivity,
ParameterSweepPoint,
to_json,
)
# Default number of evenly spaced grid points generated per parameter sweep;
# used as the default ``n`` of ``_generate_sweep_values``.
SENSITIVITY_SWEEP_POINTS = 5
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
class _NumericParamSpec(TypedDict):
    """Sweep description for a single numeric parameter found on a leaf.

    Instances are plain dicts produced by parameter discovery and read by
    the sweep to decide which values to evaluate.
    """

    # Parameter name as reported by the search metadata.
    name: str
    # Lower end of the range to sweep.
    min: float
    # Upper end of the range to sweep.
    max: float
    # Reference value for the sweep: the leaf's own value when it has one,
    # otherwise a fallback chosen during discovery.
    current: float
def _safe_float(v: object) -> float | None:
"""Convert to float, returning ``None`` for missing/unparseable values.
Delegates to :func:`safe_float` for the actual conversion (including
``inf``/``nan`` rejection) but preserves ``None`` semantics for callers
that need to distinguish "missing" from zero.
"""
if v is None:
return None
# Use a sentinel that safe_float cannot produce from valid input.
# safe_float rejects inf/nan, so inf is safe as a "not converted" marker.
sentinel = float("inf")
result = safe_float(v, default=sentinel)
if result == sentinel:
return None
return result
async def _discover_numeric_params(
    site_id: str,
    record_type: str,
    leaf: JSONObject,
) -> list[_NumericParamSpec]:
    """Build sweep specs for the numeric parameters of one leaf step.

    The leaf's search is looked up through the strategy API and every
    parameter typed ``number`` (or typed ``string`` and flagged
    ``isNumber``) is turned into a spec.  Bounds missing from the metadata
    are derived from a reference value: the leaf's own value for the
    parameter, or else the declared ``initialDisplayValue``.

    :param site_id: Site whose strategy API is queried.
    :param record_type: Record type the search belongs to.
    :param leaf: Leaf step dict; ``searchName`` and ``parameters`` are read.
    :returns: One spec per numeric parameter.  Empty when the leaf has no
        search name, the metadata lookup raises, or the response does not
        contain a parameter list.
    """
    search_name = str(leaf.get("searchName", ""))
    if not search_name:
        return []

    api = get_strategy_api(site_id)
    try:
        details = await api.client.get_search_details(record_type, search_name)
    except Exception as exc:
        # Best effort: a failed lookup just means nothing to sweep here.
        logger.warning(
            "Failed to fetch search details for numeric param discovery",
            search_name=search_name,
            record_type=record_type,
            error=str(exc),
        )
        return []

    if not isinstance(details, dict):
        return []
    search_data = details.get("searchData")
    if not isinstance(search_data, dict):
        return []
    declared = search_data.get("parameters")
    if not isinstance(declared, list):
        return []

    leaf_values = leaf.get("parameters", {})
    if not isinstance(leaf_values, dict):
        leaf_values = {}

    specs: list[_NumericParamSpec] = []
    for meta in declared:
        if not isinstance(meta, dict):
            continue

        # Numeric means: typed "number", or typed "string" with isNumber set.
        kind = str(meta.get("type", ""))
        if kind != "number" and not (
            kind == "string" and meta.get("isNumber") is True
        ):
            continue

        name = str(meta.get("name", ""))

        # Reference value: the leaf's value wins over the declared initial one.
        anchor = _safe_float(leaf_values.get(name))
        if anchor is None:
            anchor = _safe_float(meta.get("initialDisplayValue"))

        lower = _safe_float(meta.get("min"))
        if lower is None:
            # No declared minimum: start at zero unless the reference is
            # negative, in which case extend ten-fold below it.
            if anchor is not None and anchor >= 0:
                lower = 0.0
            elif anchor:
                lower = anchor * 10
            else:
                lower = 0.0

        upper = _safe_float(meta.get("max"))
        if upper is None:
            # No declared maximum: ten times a positive reference, else 100.
            upper = anchor * 10 if (anchor is not None and anchor > 0) else 100.0

        # Keep the range non-empty.
        if lower >= upper:
            upper = lower + 1.0

        specs.append(
            {
                "name": name,
                "min": lower,
                "max": upper,
                "current": anchor if anchor is not None else (lower + upper) / 2,
            }
        )
    return specs
def _generate_sweep_values(
    min_val: float,
    max_val: float,
    current: float,
    n: int = SENSITIVITY_SWEEP_POINTS,
) -> list[float]:
    """Generate sweep values including the current value and endpoints.

    Produces ``n`` evenly spaced values from ``min_val`` to ``max_val`` and
    adds ``current`` unless it already coincides with a grid point.

    :param min_val: First grid value.
    :param max_val: Last grid value (when ``n > 1``).
    :param current: Value to make sure is represented in the sweep.
    :param n: Number of grid points; with ``n == 1`` only ``min_val`` is
        generated, and with ``n <= 0`` the grid is empty.
    :returns: The values in ascending order.
    """
    import math

    if n > 1:
        step = (max_val - min_val) / (n - 1)
        values = [min_val + i * step for i in range(n)]
        # Accumulated rounding can leave the last point a hair off max_val;
        # pin it so the declared endpoint is always swept exactly.
        values[-1] = max_val
    else:
        values = [min_val] * max(n, 0)

    # Compare with a tolerance rather than exact float equality, otherwise a
    # current value that differs from a grid point only by rounding noise
    # would be evaluated twice.  The absolute tolerance scales with the span
    # so that tiny ranges keep their resolution.
    span_tol = 1e-9 * abs(max_val - min_val)
    already_present = any(
        math.isclose(v, current, rel_tol=1e-9, abs_tol=span_tol) for v in values
    )
    if not already_present:
        values.append(current)

    values.sort()
    return values
def _find_bound_partner(
    pname: str, all_param_specs: list[_NumericParamSpec]
) -> _NumericParamSpec | None:
    """Return the spec of the opposite bound paired with *pname*, if any.

    Recognised naming schemes:

    - ``foo_lower`` <-> ``foo_upper``
    - ``foo_min`` <-> ``foo_max``
    - ``MinFoo`` <-> ``MaxFoo``
    - ``min_foo`` <-> ``max_foo``

    Suffix schemes are tried first and are matched case-insensitively;
    prefix schemes are matched with exact casing.

    :param pname: Name of the parameter whose partner is wanted.
    :param all_param_specs: Specs to search, in priority order.
    :returns: The first matching spec, or ``None`` when there is none.
    """
    suffix_swap = {
        "_lower": "_upper",
        "_upper": "_lower",
        "_min": "_max",
        "_max": "_min",
    }
    lowered = pname.lower()
    hit_suffix = next((sfx for sfx in suffix_swap if lowered.endswith(sfx)), None)
    if hit_suffix is not None:
        stem = pname[: len(pname) - len(hit_suffix)]
        wanted = (stem + suffix_swap[hit_suffix]).lower()
        for spec in all_param_specs:
            if str(spec["name"]).lower() == wanted:
                return spec

    # No suffix partner: fall back to the prefix schemes.  The name must be
    # longer than the prefix itself, so a bare "min"/"max" never matches.
    prefix_swap = {"min": "max", "max": "min", "Min": "Max", "Max": "Min"}
    hit_prefix = next(
        (
            pfx
            for pfx in prefix_swap
            if len(pname) > len(pfx) and pname.startswith(pfx)
        ),
        None,
    )
    if hit_prefix is not None:
        wanted = prefix_swap[hit_prefix] + pname[len(hit_prefix) :]
        for spec in all_param_specs:
            if str(spec["name"]) == wanted:
                return spec

    return None
def _is_lower_bound(pname: str) -> bool:
    """Determine if a parameter represents a lower bound.

    The check is case-insensitive.  An explicit bound suffix is the more
    specific signal and is therefore decided first (the same precedence
    ``_find_bound_partner`` uses when pairing names); the ``min`` prefix is
    only consulted when no suffix is present.  Without this ordering a name
    such as ``minor_allele_freq_max`` would be classified as a lower bound
    merely because it starts with ``min``.

    :param pname: Parameter name to classify.
    :returns: ``True`` for a lower bound, ``False`` otherwise.
    """
    name_l = pname.lower()
    if name_l.endswith(("_upper", "_max")):
        return False
    if name_l.endswith(("_lower", "_min")):
        return True
    return name_l.startswith("min")
def _constrain_sweep_range(
pname: str,
min_val: float,
max_val: float,
partner_current: float | None,
) -> tuple[float, float]:
"""Constrain sweep range for a bound parameter based on its partner's value.
For lower-bound params: sweep from min_val to min(max_val, partner_current).
For upper-bound params: sweep from max(min_val, partner_current) to max_val.
"""
if partner_current is None:
return min_val, max_val
if _is_lower_bound(pname):
effective_max = min(max_val, partner_current)
if effective_max <= min_val:
effective_max = partner_current
return min_val, effective_max
else:
effective_min = max(min_val, partner_current)
if effective_min >= max_val:
effective_min = partner_current
return effective_min, max_val
# F1 gain over the current value that a sweep point must exceed before
# ``sweep_parameters`` recommends it.  The comparison there is strict (``>``),
# so at 0.0 any strictly positive gain qualifies.
_MINIMUM_F1_IMPROVEMENT = 0.0
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
async def sweep_parameters(
*,
site_id: str,
record_type: str,
tree: JSONObject,
controls_search_name: str,
controls_param_name: str,
controls_value_format: ControlValueFormat,
positive_controls: list[str],
negative_controls: list[str],
progress_callback: ProgressCallback | None = None,
) -> list[ParameterSensitivity]:
"""Sweep numeric params on each leaf across their WDK-declared range.
Respects paired min/max bound parameters, deduplicates identical
searches across leaves, and only recommends changes when the
improvement is meaningful.
For every swept value a deep copy of ``tree`` is patched and evaluated
with ``run_controls_against_tree``; the input ``tree`` is not modified by
the patching step.  Sweep points whose evaluation raises are logged and
left out of the result.
:param site_id: Site identifier, forwarded to parameter discovery and to
``run_controls_against_tree``.
:param record_type: Record type, forwarded alongside ``site_id``.
:param tree: ``PlanStepNode``-shaped dict.
:param controls_search_name: Forwarded unchanged to
``run_controls_against_tree``.
:param controls_param_name: Forwarded unchanged to
``run_controls_against_tree``.
:param controls_value_format: Forwarded unchanged to
``run_controls_against_tree``.
:param positive_controls: Forwarded unchanged to
``run_controls_against_tree``.
:param negative_controls: Forwarded unchanged to
``run_controls_against_tree``.
:param progress_callback: Optional awaitable callback.  It receives a
``step_analysis_progress`` event before each parameter is swept and
another one, carrying the serialized result, after it.
:returns: One :class:`ParameterSensitivity` per numeric param.
"""
leaves = collect_dict_leaves(tree)
if not leaves:
return []
# Deduplicate: only sweep unique (searchName, param) combinations.
# When multiple leaves share the same search, sweep only the first one
# since they would produce identical results.
seen_search_params: set[str] = set()
# Each entry: (leaf, spec to sweep, all specs of that leaf). The full list
# is kept so the bound partner can be looked up on the same leaf.
all_specs: list[tuple[JSONObject, _NumericParamSpec, list[_NumericParamSpec]]] = []
for leaf in leaves:
search_name = str(leaf.get("searchName", ""))
params = await _discover_numeric_params(site_id, record_type, leaf)
for spec in params:
dedup_key = f"{search_name}:{spec['name']}"
if dedup_key in seen_search_params:
logger.debug(
"Skipping duplicate sweep",
search=search_name,
param=spec["name"],
)
continue
seen_search_params.add(dedup_key)
all_specs.append((leaf, spec, params))
if not all_specs:
return []
results: list[ParameterSensitivity] = []
# Shared by all sweeps: at most three control evaluations run concurrently.
sem = asyncio.Semaphore(3)
total_params = len(all_specs)
for pi, (leaf, spec, leaf_all_params) in enumerate(all_specs):
lid = _node_id(leaf)
pname = str(spec["name"])
min_val = spec["min"]
max_val = spec["max"]
current = spec["current"]
# Constrain sweep range for paired min/max bound parameters
partner_spec = _find_bound_partner(pname, leaf_all_params)
partner_current: float | None = None
if partner_spec is not None:
partner_current = partner_spec["current"]
min_val, max_val = _constrain_sweep_range(
pname, min_val, max_val, partner_current
)
# NOTE(review): indentation is not visible in this view, so it cannot be
# confirmed here whether the range guard and the clamp below run only when
# a partner exists or for every parameter -- check the original file.
if min_val >= max_val:
max_val = min_val + 1.0
# Clamp current into the constrained range so the sweep is sensible
current = max(min_val, min(max_val, current))
sweep_values = _generate_sweep_values(min_val, max_val, current)
if progress_callback:
constraint_note = ""
if partner_spec is not None:
partner_name = str(partner_spec["name"])
constraint_note = f" (bounded by {partner_name}={partner_current})"
await progress_callback(
{
"type": "step_analysis_progress",
"data": {
"phase": "sensitivity",
"message": (
f"Sweeping {pname} on {lid} "
f"({pi + 1}/{total_params}, {len(sweep_values)} points)"
f"{constraint_note}"
),
"current": pi + 1,
"total": total_params,
},
}
)
# Placeholder; rebound below once the sweep results have been gathered.
sweep_points: list[ParameterSweepPoint] = []
# The defaults bind this iteration's step id and parameter name at
# definition time, so the coroutine does not depend on the loop variables.
async def _eval_value(
val: float,
_lid: str = lid,
_pname: str = pname,
) -> ParameterSweepPoint | None:
# Patch a private copy so concurrent sweep points do not see each other.
modified = copy.deepcopy(tree)
def _patch_node(node: JSONObject) -> None:
if _node_id(node) == _lid:
params = node.get("parameters")
if not isinstance(params, dict):
params = {}
node["parameters"] = params
# The swept value is written as its string form.
params[_pname] = str(val)
walk_dict_tree(modified, _patch_node)
try:
async with sem:
raw = await run_controls_against_tree(
site_id=site_id,
record_type=record_type,
tree=modified,
controls_search_name=controls_search_name,
controls_param_name=controls_param_name,
controls_value_format=controls_value_format,
positive_controls=positive_controls,
negative_controls=negative_controls,
)
except Exception as exc:
# Best effort: a failed point is logged and dropped (filtered out below).
logger.warning(
"Sensitivity sweep point failed",
step=_lid,
param=_pname,
value=val,
error=str(exc),
)
return None
counts = _extract_eval_counts(raw)
# Rates fall back to 0.0 when the corresponding control set is empty.
recall = counts.pos_hits / counts.pos_total if counts.pos_total > 0 else 0.0
fpr = counts.neg_hits / counts.neg_total if counts.neg_total > 0 else 0.0
return ParameterSweepPoint(
# Reported value is rounded to 6 decimals; the tree was patched with
# the unrounded value.
value=round(val, 6),
positive_hits=counts.pos_hits,
negative_hits=counts.neg_hits,
total_results=counts.total_results,
recall=recall,
fpr=fpr,
f1=_f1_from_counts(counts),
)
tasks = [_eval_value(v) for v in sweep_values]
points = await asyncio.gather(*tasks)
sweep_points = [p for p in points if p is not None]
sweep_points.sort(key=lambda p: p.value)
# Pick the best point, but only recommend a change if the F1
# improvement is meaningful (> threshold). This prevents noisy
# recommendations when multiple values score similarly.
# The baseline is the evaluated point closest to ``current``; if the point
# at ``current`` itself failed, this is a neighbouring point.
cur_point = (
min(sweep_points, key=lambda p: abs(p.value - current))
if sweep_points
else None
)
cur_f1 = cur_point.f1 if cur_point else 0.0
# max() keeps the first maximal element, so F1 ties go to the smallest
# swept value (the list was sorted ascending above).
best_point = max(sweep_points, key=lambda p: p.f1) if sweep_points else None
improvement = (best_point.f1 - cur_f1) if best_point else 0.0
if best_point and improvement > _MINIMUM_F1_IMPROVEMENT:
recommended_value = best_point.value
else:
recommended_value = current
best_point = cur_point
# For bound params, validate the recommendation doesn't violate its partner
if (
partner_current is not None
and best_point
and (
(_is_lower_bound(pname) and recommended_value > partner_current)
or (not _is_lower_bound(pname) and recommended_value < partner_current)
)
):
recommended_value = current
# Text is only produced when the recommended value differs from the current
# one by more than 1e-6; otherwise it stays empty.
recommendation = ""
if best_point and abs(recommended_value - current) > 1e-6 and cur_point:
recommendation = (
f"Changing {pname} from {current:.4g} to {recommended_value:.4g} "
f"changes recall {cur_point.recall:.0%} -> {best_point.recall:.0%}"
f" and FPR {cur_point.fpr:.0%} -> {best_point.fpr:.0%}"
)
ps = ParameterSensitivity(
step_id=lid,
param_name=pname,
current_value=current,
sweep_points=sweep_points,
recommended_value=recommended_value,
recommendation=recommendation,
)
results.append(ps)
if progress_callback:
await progress_callback(
{
"type": "step_analysis_progress",
"data": {
"phase": "sensitivity",
"message": recommendation
or f"Parameter {pname}: no change recommended",
"current": pi + 1,
"total": total_params,
"parameterSensitivity": to_json(ps),
},
}
)
logger.info("Parameter sensitivity complete", count=len(results))
return results