Source code for veupath_chatbot.transport.http.routers.experiments.cross_validation

"""Cross-validation endpoint for experiments."""

from typing import cast

from fastapi import APIRouter

from veupath_chatbot.platform.errors import WDKError
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.store import get_experiment_store
from veupath_chatbot.transport.http.deps import CurrentUser, ExperimentDep
from veupath_chatbot.transport.http.schemas.experiments import (
    RunCrossValidationRequest,
)

router = APIRouter()
logger = get_logger(__name__)


@router.post("/{experiment_id}/cross-validate")
async def run_cv(
    exp: ExperimentDep,
    request: RunCrossValidationRequest,
    user_id: CurrentUser,
) -> JSONObject:
    """Cross-validate an existing experiment and persist the outcome.

    The experiment's stored configuration and metrics are passed to
    ``run_cross_validation`` together with ``request.k_folds``. The result
    is attached to the experiment, saved through the experiment store, and
    returned in JSON form.

    Raises:
        WDKError: propagated unchanged when the run raises one; any other
            exception is logged and re-raised wrapped in a ``WDKError``.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably this avoids an import cycle -- confirm.
    from veupath_chatbot.services.experiment.cross_validation import (
        run_cross_validation,
    )
    from veupath_chatbot.services.experiment.types import to_json

    # ``user_id`` is never read in this body.
    # NOTE(review): presumably declared only so the CurrentUser dependency
    # is resolved for the request -- confirm.
    try:
        # Attribute lookups stay inside the ``try`` so a failure while
        # reading the experiment is handled the same way as a failure in
        # the run itself.
        cfg = exp.config
        step_tree = cfg.step_tree
        tree_mode = cfg.is_tree_mode

        result = await run_cross_validation(
            site_id=cfg.site_id,
            record_type=cfg.record_type,
            controls_search_name=cfg.controls_search_name,
            controls_param_name=cfg.controls_param_name,
            controls_value_format=cfg.controls_value_format,
            positive_controls=cfg.positive_controls,
            negative_controls=cfg.negative_controls,
            # Only a dict-shaped step tree is forwarded.
            tree=step_tree if isinstance(step_tree, dict) else None,
            # The single-search arguments are withheld in tree mode.
            search_name=None if tree_mode else cfg.search_name,
            parameters=None if tree_mode else cfg.parameters,
            k=request.k_folds,
            full_metrics=exp.metrics,
        )
    except WDKError:
        # Already the error type this endpoint raises; pass it through.
        raise
    except Exception as exc:
        logger.error(
            "Cross-validation failed",
            experiment_id=exp.id,
            error=str(exc),
            exc_info=True,
        )
        raise WDKError(f"Cross-validation failed: {exc}") from exc

    # Record the result on the experiment before saving it.
    exp.cross_validation = result
    get_experiment_store().save(exp)
    return cast(JSONObject, to_json(result))