# Source code for veupath_chatbot.services.experiment.core.streaming

"""Background task launchers for experiment execution — CQRS version.

Events are persisted to Redis Streams. Operations are registered in PostgreSQL.
"""

import asyncio
import copy
import json
from uuid import UUID, uuid4

from veupath_chatbot.persistence.repositories.stream import StreamRepository
from veupath_chatbot.persistence.repositories.user import UserRepository
from veupath_chatbot.persistence.session import async_session_factory
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.redis import get_redis
from veupath_chatbot.platform.tasks import spawn
from veupath_chatbot.platform.types import JSONObject
from veupath_chatbot.services.experiment.helpers import ProgressCallback
from veupath_chatbot.services.experiment.service import run_experiment
from veupath_chatbot.services.experiment.types import (
    BatchExperimentConfig,
    Experiment,
    ExperimentConfig,
    experiment_to_json,
)

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Experiments are not tied to a conversation stream, so every experiment
# operation is registered against one shared, synthetic stream identified by
# a fixed UUID. _register_experiment_operation creates that stream row on
# first use.
_EXPERIMENT_STREAM_ID = UUID("00000000-0000-0000-0000-000000000001")
# Fixed user that owns the synthetic stream above; it is fetched or created
# right before the stream itself is created.
_EXPERIMENT_USER_ID = UUID("00000000-0000-0000-0000-000000000000")


async def _emit_to_redis(
    operation_id: str,
    event_type: str,
    event_data: JSONObject,
) -> None:
    """Append one event to the Redis stream of an experiment operation.

    The entry is added to the key ``op:<operation_id>`` and carries three
    byte-encoded fields: ``op`` (the operation ID), ``type`` (the event
    type) and ``data`` (the payload serialized as JSON). Payload values that
    are not JSON-native are stringified through ``default=str``.
    """
    client = get_redis()
    stream_key = f"op:{operation_id}"

    op_field = operation_id.encode()
    type_field = event_type.encode()
    data_field = json.dumps(event_data, default=str).encode()

    await client.xadd(
        stream_key,
        {"op": op_field, "type": type_field, "data": data_field},
    )


def _make_progress_callback(
    operation_id: str,
) -> ProgressCallback:
    """Create a progress callback bound to *operation_id*.

    The returned coroutine function forwards each progress event to the
    operation's Redis stream. The event type is taken from the event's
    ``"type"`` entry when that is a string, otherwise it falls back to
    ``"experiment_progress"``. The payload is the event's ``"data"`` entry
    when that is a dict, otherwise the whole event is sent.
    """
    fallback_type = "experiment_progress"

    async def _forward(evt: JSONObject) -> None:
        kind = evt.get("type", fallback_type)
        if not isinstance(kind, str):
            kind = fallback_type

        payload = evt.get("data")
        if not isinstance(payload, dict):
            # No structured "data" entry: ship the event itself.
            payload = evt

        await _emit_to_redis(operation_id, kind, payload)

    return _forward


async def _finalize_operation(operation_id: str, *, failed: bool) -> None:
    """Record the terminal state of an experiment operation.

    Opens a fresh session, calls ``fail_operation`` when *failed* is true
    and ``complete_operation`` otherwise, then commits.
    """
    async with async_session_factory() as db:
        operations = StreamRepository(db)
        finalize = (
            operations.fail_operation if failed else operations.complete_operation
        )
        await finalize(operation_id)
        await db.commit()


async def _register_experiment_operation(operation_id: str, op_type: str) -> None:
    """Persist a new experiment operation before its background task starts.

    All experiment operations hang off the shared experiment stream. If that
    stream does not exist yet, its owning user is fetched or created and the
    stream is created first. The operation is then registered against the
    stream and everything is committed in one transaction, so the subscribe
    endpoint can find the operation immediately.
    """
    async with async_session_factory() as db:
        streams = StreamRepository(db)

        existing = await streams.get_by_id(_EXPERIMENT_STREAM_ID)
        if not existing:
            # First use: bootstrap the owning user, then the stream itself.
            await UserRepository(db).get_or_create(_EXPERIMENT_USER_ID)
            await streams.create(
                user_id=_EXPERIMENT_USER_ID,
                site_id="system",
                stream_id=_EXPERIMENT_STREAM_ID,
                name="Experiments",
            )

        await streams.register_operation(
            operation_id, _EXPERIMENT_STREAM_ID, op_type
        )
        await db.commit()


async def start_experiment(
    config: ExperimentConfig, *, user_id: str | None = None
) -> str:
    """Launch a single experiment as a background task. Returns operation ID.

    The operation is registered in the database before the task is spawned.
    The background task emits these events to the operation's Redis stream:

    * ``experiment_complete`` with the serialized experiment on success;
    * ``experiment_error`` with ``{"error": <message>}`` when the run raises;
    * ``experiment_end`` (empty payload) in every case.

    Afterwards the operation is marked completed or failed.

    :param config: Configuration of the experiment to run.
    :param user_id: Optional user ID forwarded to ``run_experiment``.
    :returns: The generated operation ID (``op_`` + 12 hex characters).
    """
    operation_id = f"op_{uuid4().hex[:12]}"
    await _register_experiment_operation(operation_id, "experiment")

    async def _run() -> None:
        # Pessimistic default: the operation only counts as successful once
        # the completion event has been emitted. This also covers task
        # cancellation (CancelledError is a BaseException and bypasses the
        # ``except Exception`` below), which must not be recorded as success.
        failed = True
        try:
            result = await run_experiment(
                config,
                user_id=user_id,
                progress_callback=_make_progress_callback(operation_id),
            )
            await _emit_to_redis(
                operation_id,
                "experiment_complete",
                experiment_to_json(result),
            )
            failed = False
        except Exception as exc:
            logger.error("Experiment failed", error=str(exc), exc_info=True)
            await _emit_to_redis(operation_id, "experiment_error", {"error": str(exc)})
        finally:
            # The operation must reach a terminal state even when emitting
            # the closing event itself fails (e.g. Redis is unreachable).
            try:
                await _emit_to_redis(operation_id, "experiment_end", {})
            finally:
                await _finalize_operation(operation_id, failed=failed)

    spawn(_run())
    return operation_id
async def start_batch_experiment(
    batch_config: BatchExperimentConfig, *, user_id: str | None = None
) -> str:
    """Launch a batch experiment as a background task. Returns operation ID.

    One experiment is derived from ``batch_config.base_config`` per target
    organism and the experiments are run sequentially. A failure of a single
    organism is logged and skipped; the remaining organisms still run.

    The background task emits ``batch_complete`` (batch ID plus the
    serialized experiments that succeeded) or ``batch_error`` to the
    operation's Redis stream, then marks the operation completed or failed.

    :param batch_config: Base configuration, organism parameter name and the
        target organisms with their optional control overrides.
    :param user_id: Optional user ID forwarded to ``run_experiment``.
    :returns: The generated operation ID (``op_`` + 12 hex characters).
    """
    import time

    from veupath_chatbot.services.experiment.store import get_experiment_store

    operation_id = f"op_{uuid4().hex[:12]}"
    # Wall-clock epoch milliseconds. The event loop clock is monotonic with
    # an arbitrary reference point, so IDs derived from it can repeat after
    # a process restart or host reboot.
    batch_id = f"batch_{time.time_ns() // 1_000_000}"
    await _register_experiment_operation(operation_id, "batch")

    async def _run() -> None:
        # Pessimistic default: only a fully emitted ``batch_complete`` counts
        # as success. This also covers task cancellation, which bypasses the
        # ``except Exception`` handlers below.
        failed = True
        try:
            base = batch_config.base_config
            org_param = batch_config.organism_param_name
            results: list[Experiment] = []
            store = get_experiment_store()
            # The callback depends only on the operation, not on the organism.
            progress_callback = _make_progress_callback(operation_id)

            for target in batch_config.target_organisms:
                params = dict(base.parameters)
                params[org_param] = target.organism
                # Mutable members of the base config are copied so that each
                # organism's experiment gets its own instances.
                org_config = ExperimentConfig(
                    site_id=base.site_id,
                    record_type=base.record_type,
                    search_name=base.search_name,
                    parameters=params,
                    positive_controls=target.positive_controls
                    if target.positive_controls is not None
                    else list(base.positive_controls),
                    negative_controls=target.negative_controls
                    if target.negative_controls is not None
                    else list(base.negative_controls),
                    controls_search_name=base.controls_search_name,
                    controls_param_name=base.controls_param_name,
                    controls_value_format=base.controls_value_format,
                    enable_cross_validation=base.enable_cross_validation,
                    k_folds=base.k_folds,
                    enrichment_types=list(base.enrichment_types),
                    name=f"{base.name} ({target.organism})",
                    description=base.description,
                    optimization_specs=(
                        copy.deepcopy(base.optimization_specs)
                        if base.optimization_specs
                        else None
                    ),
                    optimization_budget=base.optimization_budget,
                    optimization_objective=base.optimization_objective,
                    parameter_display_values=(
                        dict(base.parameter_display_values)
                        if base.parameter_display_values
                        else None
                    ),
                )
                try:
                    exp = await run_experiment(
                        org_config,
                        user_id=user_id,
                        progress_callback=progress_callback,
                    )
                    exp.batch_id = batch_id
                    store.save(exp)
                    results.append(exp)
                except Exception as exc:
                    # Best effort: one failing organism must not abort the
                    # whole batch. Keep the traceback for diagnosis.
                    logger.error(
                        "Batch organism experiment failed",
                        organism=target.organism,
                        error=str(exc),
                        exc_info=True,
                    )

            await _emit_to_redis(
                operation_id,
                "batch_complete",
                {
                    "batchId": batch_id,
                    "experiments": [experiment_to_json(e) for e in results],
                },
            )
            failed = False
        except Exception as exc:
            logger.error("Batch experiment failed", error=str(exc), exc_info=True)
            await _emit_to_redis(operation_id, "batch_error", {"error": str(exc)})
        finally:
            await _finalize_operation(operation_id, failed=failed)

    spawn(_run())
    return operation_id
async def start_benchmark(
    base_config: ExperimentConfig,
    control_sets: list[tuple[str, list[str], list[str], str | None, bool]],
    *,
    user_id: str | None = None,
) -> str:
    """Launch a benchmark suite as a background task. Returns operation ID.

    For every control set a deep copy of *base_config* is run with that
    set's controls; all runs execute concurrently. A failing run is logged
    and left out of the result, it does not abort the suite.

    The background task emits ``benchmark_complete`` (benchmark ID plus the
    serialized experiments that succeeded) or ``benchmark_error`` to the
    operation's Redis stream, then marks the operation completed or failed.

    :param base_config: Configuration shared by all runs of the suite.
    :param control_sets: One tuple per run, laid out as ``(label,
        positive_controls, negative_controls, control_set_id, is_primary)``.
    :param user_id: Optional user ID forwarded to ``run_experiment``.
    :returns: The generated operation ID (``op_`` + 12 hex characters).
    """
    import time

    from veupath_chatbot.services.experiment.store import get_experiment_store

    operation_id = f"op_{uuid4().hex[:12]}"
    # Wall-clock epoch milliseconds. The event loop clock is monotonic with
    # an arbitrary reference point, so IDs derived from it can repeat after
    # a process restart or host reboot.
    benchmark_id = f"bench_{time.time_ns() // 1_000_000}"
    await _register_experiment_operation(operation_id, "benchmark")

    async def _run() -> None:
        # Pessimistic default: only a fully emitted ``benchmark_complete``
        # counts as success. This also covers task cancellation, which
        # bypasses the ``except Exception`` handlers below.
        failed = True
        try:

            async def _run_one(
                label: str,
                positives: list[str],
                negatives: list[str],
                control_set_id: str | None,
                is_primary: bool,
            ) -> Experiment | None:
                """Run one control set; return ``None`` when the run fails."""
                cfg = copy.deepcopy(base_config)
                cfg.positive_controls = positives
                cfg.negative_controls = negatives
                cfg.name = f"{base_config.name} [{label}]"
                cfg.control_set_id = control_set_id
                try:
                    exp = await run_experiment(
                        cfg,
                        user_id=user_id,
                        progress_callback=_make_progress_callback(operation_id),
                    )
                    exp.benchmark_id = benchmark_id
                    exp.control_set_label = label
                    exp.is_primary_benchmark = is_primary
                    store = get_experiment_store()
                    store.save(exp)
                    return exp
                except Exception as exc:
                    # Best effort: one failing control set must not abort
                    # the suite. Keep the traceback for diagnosis.
                    logger.error(
                        "Benchmark experiment failed",
                        label=label,
                        error=str(exc),
                        exc_info=True,
                    )
                    return None

            tasks = [
                _run_one(label, pos, neg, csid, primary)
                for label, pos, neg, csid, primary in control_sets
            ]
            results = await asyncio.gather(*tasks)
            completed = [r for r in results if r is not None]
            await _emit_to_redis(
                operation_id,
                "benchmark_complete",
                {
                    "benchmarkId": benchmark_id,
                    "experiments": [experiment_to_json(e) for e in completed],
                },
            )
            failed = False
        except Exception as exc:
            logger.error("Benchmark suite failed", error=str(exc), exc_info=True)
            await _emit_to_redis(operation_id, "benchmark_error", {"error": str(exc)})
        finally:
            await _finalize_operation(operation_id, failed=failed)

    spawn(_run())
    return operation_id