Source code for veupath_chatbot.integrations.embeddings.openai_embeddings

import asyncio
from collections.abc import Iterable
from dataclasses import dataclass, field

from veupath_chatbot.platform.config import get_settings
from veupath_chatbot.platform.errors import InternalError


def _chunks(items: list[str], *, size: int) -> Iterable[list[str]]:
    if size <= 0:
        raise ValueError("chunk size must be > 0")
    for i in range(0, len(items), size):
        yield items[i : i + size]


def _resolve_embeddings_config() -> tuple[str, str | None]:
    """Pick the ``(api_key, base_url)`` pair for the embeddings client.

    Resolution order, based on the application settings:
      1. ``embeddings_base_url`` is set → that URL, paired with the OpenAI
         key, or the placeholder ``"local"`` when no key is set.
      2. ``openai_api_key`` is set → that key with ``None`` as the base URL,
         leaving the endpoint to the client's default.
      3. Otherwise → the placeholder key ``"ollama"`` with
         ``ollama_base_url``, returned as-is with no check that it is set.
    """
    cfg = get_settings()
    custom_url = cfg.embeddings_base_url
    openai_key = cfg.openai_api_key

    # An explicit embeddings endpoint wins over everything else.
    if custom_url:
        return (openai_key if openai_key else "local"), custom_url

    # No custom endpoint: a configured OpenAI key means stock OpenAI.
    if openai_key:
        return openai_key, None

    # Last resort: Ollama, which is given a placeholder key.
    return "ollama", cfg.ollama_base_url


@dataclass(frozen=True)
class OpenAIEmbeddings:
    """Batching wrapper around an OpenAI-compatible embeddings endpoint.

    Works with OpenAI, Ollama, or any server exposing ``/v1/embeddings``.
    """

    # Model identifier forwarded as ``model=`` on every embeddings request.
    model: str
    # Maximum number of texts sent in a single embeddings request.
    batch_size: int = 128
    # Optional endpoint override; when truthy it takes precedence over the
    # base URL resolved from settings.
    base_url: str | None = field(default=None)

    async def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in batches and return one vector per input text.

        An empty ``texts`` returns ``[]`` immediately, before the SDK is
        imported or any configuration is read.

        Raises:
            InternalError: If the ``openai`` package cannot be imported, or
                if the number of vectors collected differs from the number
                of input texts.
            ValueError: If ``batch_size`` is not positive (raised by
                ``_chunks`` when batching starts).
        """
        if not texts:
            return []

        # Imported lazily so the module can be loaded without the SDK.
        try:
            from openai import AsyncOpenAI
        except Exception as exc:  # pragma: no cover
            raise InternalError(
                title="OpenAI SDK not available",
                detail="Install `openai` (or `kani[openai]`) to enable embeddings.",
            ) from exc

        api_key, settings_base = _resolve_embeddings_config()
        # A truthy instance-level base_url beats the settings-derived one.
        target_base = self.base_url if self.base_url else settings_base

        embeddings: list[list[float]] = []
        async with AsyncOpenAI(api_key=api_key, base_url=target_base) as client:
            for chunk in _chunks(texts, size=self.batch_size):
                response = await client.embeddings.create(
                    model=self.model,
                    input=chunk,
                )
                for item in response.data:
                    embeddings.append(item.embedding)
                # Yield to the event loop between batches.
                await asyncio.sleep(0)

        if len(embeddings) != len(texts):  # pragma: no cover (SDK contract)
            raise InternalError(title="Embedding count mismatch")
        return embeddings
async def embed_one(*, text: str, model: str) -> list[float]:
    """Embed a single string and return its vector.

    Convenience helper for one-off vector size detection: it builds a
    throwaway ``OpenAIEmbeddings`` with ``batch_size=1`` and returns the
    first (only) vector from ``embed_texts``.
    """
    single_shot = OpenAIEmbeddings(model=model, batch_size=1)
    vectors = await single_shot.embed_texts([text])
    return vectors[0]