Source code for veupath_chatbot.services.chat.mention_context

"""Build rich context blocks for @-mentioned strategies and experiments.

When a user @-mentions a strategy or experiment in chat, we load the full
entity and format a human-readable context block that gets appended to the
system prompt so the model has complete information from the start.
"""

import json
from typing import Literal

from veupath_chatbot.persistence.repositories.stream import StreamRepository
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.services.experiment.store import get_experiment_store
from veupath_chatbot.services.experiment.types import ExperimentMetrics

logger = get_logger(__name__)

MentionType = Literal["strategy", "experiment"]


[docs] async def build_mention_context( mentions: list[dict[str, str]], stream_repo: StreamRepository, ) -> str: """Build concatenated context blocks for all mentions. :param mentions: List of ``{"type": ..., "id": ..., "displayName": ...}`` dicts. :param stream_repo: Repository for loading stream projections. :returns: Markdown context string (empty if no mentions resolved). """ blocks: list[str] = [] for m in mentions: m_type = m.get("type") m_id = m.get("id", "") if m_type == "strategy": block = await _build_strategy_context(m_id, stream_repo) if block: blocks.append(block) elif m_type == "experiment": block = await _build_experiment_context(m_id) if block: blocks.append(block) else: logger.debug("Unknown mention type", mention_type=m_type, mention_id=m_id) return "\n\n".join(blocks)
async def _build_strategy_context( strategy_id: str, stream_repo: StreamRepository, ) -> str | None: """Load a stream projection and format a rich context block.""" from uuid import UUID try: sid = UUID(strategy_id) except ValueError: logger.warning("Invalid strategy mention ID", strategy_id=strategy_id) return None projection = await stream_repo.get_projection(sid) if not projection: logger.warning("Mentioned strategy not found", strategy_id=strategy_id) return None lines: list[str] = [ f'## Referenced Strategy: "{projection.name}"', f"- **ID**: {projection.stream_id}", f"- **Record type**: {projection.record_type or 'unknown'}", ] steps = projection.steps if isinstance(steps, list) and steps: lines.append(f"- **Steps** ({len(steps)}):") lines.append("") for i, step in enumerate(steps): if not isinstance(step, dict): continue display_name = ( step.get("displayName") or step.get("searchName") or f"Step {i + 1}" ) kind = step.get("kind") or "search" step_id = step.get("id", "?") lines.append(f"### Step {i + 1}: {display_name} ({kind}) [id={step_id}]") search_name = step.get("searchName") if search_name: lines.append(f"- Search: `{search_name}`") params = step.get("parameters") if isinstance(params, dict) and params: non_empty = {k: v for k, v in params.items() if v} if non_empty: param_strs = [ f"`{k}`: {_truncate(str(v), 80)}" for k, v in non_empty.items() ] lines.append(f"- Parameters: {', '.join(param_strs)}") result_count = step.get("resultCount") if result_count is not None: lines.append(f"- Result count: {result_count}") primary = step.get("primaryInputStepId") secondary = step.get("secondaryInputStepId") operator = step.get("operator") if primary and secondary and operator: lines.append( f"- Combines: step {primary} **{operator}** step {secondary}" ) elif primary: lines.append(f"- Input: step {primary}") lines.append("") elif isinstance(projection.plan, dict): lines.append("") lines.append("### Strategy plan (AST):") lines.append("```json") lines.append(json.dumps(projection.plan, indent=2, default=str)[:4000]) lines.append("```") return "\n".join(lines) async def _build_experiment_context(experiment_id: str) -> str | None: """Load an experiment and format a rich context block.""" store = get_experiment_store() experiment = await store.aget(experiment_id) if not experiment: logger.warning("Mentioned experiment not found", experiment_id=experiment_id) return None cfg = experiment.config lines: list[str] = [ f'## Referenced Experiment: "{cfg.name or experiment.id}"', f"- **Status**: {experiment.status}", f"- **Search**: `{cfg.search_name}` on `{cfg.record_type}`", ] if cfg.parameters: non_empty = {k: v for k, v in cfg.parameters.items() if v} if non_empty: param_strs = [ f"`{k}`: {_truncate(str(v), 60)}" for k, v in non_empty.items() ] lines.append(f"- **Parameters**: {', '.join(param_strs)}") if cfg.positive_controls: ids = ", ".join(cfg.positive_controls[:20]) suffix = ( f" ... ({len(cfg.positive_controls)} total)" if len(cfg.positive_controls) > 20 else "" ) lines.append( f"- **Positive controls** ({len(cfg.positive_controls)}): {ids}{suffix}" ) if cfg.negative_controls: ids = ", ".join(cfg.negative_controls[:20]) suffix = ( f" ... ({len(cfg.negative_controls)} total)" if len(cfg.negative_controls) > 20 else "" ) lines.append( f"- **Negative controls** ({len(cfg.negative_controls)}): {ids}{suffix}" ) metrics = experiment.metrics if metrics: lines.append("") lines.append("### Metrics") lines.append(_format_metrics(metrics)) cv = experiment.cross_validation if cv: lines.append("") lines.append( f"### Cross-validation ({cv.k}-fold): " f"overfitting={cv.overfitting_level} (score={cv.overfitting_score:.3f})" ) lines.append(f"- Mean F1: {cv.mean_metrics.f1_score:.4f}") lines.append(f"- Mean sensitivity: {cv.mean_metrics.sensitivity:.4f}") lines.append(f"- Mean specificity: {cv.mean_metrics.specificity:.4f}") for er in experiment.enrichment_results[:3]: if er.terms: lines.append("") lines.append( f"### Enrichment: {er.analysis_type} ({er.total_genes_analyzed} genes)" ) for term in er.terms[:8]: lines.append( f"- {term.term_name} ({term.gene_count} genes, " f"p={term.p_value:.2e}, FDR={term.fdr:.2e})" ) if len(er.terms) > 8: lines.append(f"- ... {len(er.terms) - 8} more terms") if experiment.optimization_result: best = experiment.optimization_result.get("bestTrial") if isinstance(best, dict): lines.append("") lines.append("### Optimization result") lines.append(f"- Best score: {best.get('score', '?')}") params = best.get("parameters") if isinstance(params, dict): param_strs = [f"`{k}`: {v}" for k, v in params.items()] lines.append(f"- Best parameters: {', '.join(param_strs)}") return "\n".join(lines) def _format_metrics(m: ExperimentMetrics) -> str: """Format metrics as a compact table.""" return ( f"| Metric | Value |\n" f"|--------|-------|\n" f"| Sensitivity | {m.sensitivity:.4f} |\n" f"| Specificity | {m.specificity:.4f} |\n" f"| Precision | {m.precision:.4f} |\n" f"| F1 Score | {m.f1_score:.4f} |\n" f"| MCC | {m.mcc:.4f} |\n" f"| Balanced Accuracy | {m.balanced_accuracy:.4f} |\n" f"| Total Results | {m.total_results} |\n" f"| Confusion Matrix | TP={m.confusion_matrix.true_positives} FP={m.confusion_matrix.false_positives} " f"FN={m.confusion_matrix.false_negatives} TN={m.confusion_matrix.true_negatives} |" ) def _truncate(s: str, max_len: int) -> str: """Truncate a string with ellipsis.""" if len(s) <= max_len: return s # Too small for "..." suffix — just hard-truncate. if max_len < 4: return s[:max_len] return s[: max_len - 3] + "..."