from collections.abc import Awaitable, Callable
from qdrant_client import AsyncQdrantClient
from veupath_chatbot.integrations.embeddings.openai_embeddings import OpenAIEmbeddings
from veupath_chatbot.integrations.vectorstore.collections import (
WDK_RECORD_TYPES_V1,
WDK_SEARCHES_V1,
)
from veupath_chatbot.integrations.vectorstore.ingest.pipeline import (
run_concurrent_pipeline,
)
from veupath_chatbot.integrations.vectorstore.ingest.utils import (
embed_and_upsert,
existing_point_ids,
)
from veupath_chatbot.integrations.vectorstore.qdrant_store import (
QdrantStore,
point_uuid,
)
from veupath_chatbot.integrations.veupathdb.param_utils import wdk_entity_name
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.types import JSONArray, JSONObject, JSONValue
# Module-level logger. Calls in this module pass structured keyword fields
# (e.g. ``siteId=...``, ``skipped=...``) alongside the event message.
logger = get_logger(__name__)
[docs]
async def filter_existing_record_types(
    qdrant_client: AsyncQdrantClient,
    record_type_docs: JSONArray,
    site_id: str,
) -> JSONArray:
    """Return the record-type docs whose point id is not yet stored in Qdrant.

    Every dict doc with a non-``None`` ``id`` is looked up, as a string, in the
    ``WDK_RECORD_TYPES_V1`` collection. Docs whose id is already stored are
    dropped, and the number dropped is logged. Entries that are not dicts, or
    that carry no ``id``, cannot be matched against Qdrant. They are always
    passed through unchanged, regardless of whether any other doc matched.

    Args:
        qdrant_client: Client used for the existence lookup.
        record_type_docs: Candidate docs, in indexing order.
        site_id: Site identifier; only used in the log record.

    Returns:
        The docs still to be indexed, in their original order. When nothing
        needs to be dropped, the input list itself is returned.
    """

    def _doc_id(doc: JSONValue) -> str | None:
        # Ids are compared as strings, the same form sent to the lookup.
        if isinstance(doc, dict):
            id_value = doc.get("id")
            if id_value is not None:
                return str(id_value)
        return None

    rt_ids: list[str] = [
        doc_id for d in record_type_docs if (doc_id := _doc_id(d)) is not None
    ]
    if not rt_ids:
        # No ids means nothing can match, so skip the round-trip to Qdrant.
        return record_type_docs
    existing_rts = await existing_point_ids(
        qdrant_client=qdrant_client, collection=WDK_RECORD_TYPES_V1, ids=rt_ids
    )
    if not existing_rts:
        return record_type_docs
    # Only docs whose id is known to exist are removed. Id-less or non-dict
    # entries yield None here and are therefore kept.
    filtered_docs: JSONArray = [
        d for d in record_type_docs if _doc_id(d) not in existing_rts
    ]
    skipped = len(record_type_docs) - len(filtered_docs)
    if skipped:
        logger.info("WDK record types skipped", siteId=site_id, skipped=skipped)
    return filtered_docs
[docs]
async def filter_existing_searches(
    qdrant_client: AsyncQdrantClient,
    searches_to_fetch: list[tuple[str, JSONObject]],
    site_id: str,
) -> list[tuple[str, JSONObject]]:
    """Select the searches that are eligible and not yet indexed in Qdrant.

    A ``(record_type, search)`` pair is eligible when ``search`` is a dict,
    is not flagged ``isInternal``, and has a non-empty name according to
    ``wdk_entity_name``. Each eligible pair is assigned the point id
    ``point_uuid("<site_id>:<record_type>:<search_name>")``. Pairs whose id
    already exists in ``WDK_SEARCHES_V1`` are dropped, and the number dropped
    is logged.

    Returns:
        The remaining ``(record_type, search)`` pairs, in input order.
    """
    candidates: list[tuple[str, JSONObject, str]] = []
    for record_type, search in searches_to_fetch:
        if isinstance(search, dict) and not search.get("isInternal", False):
            name = wdk_entity_name(search)
            if name:
                point_id = point_uuid(f"{site_id}:{record_type}:{name}")
                candidates.append((record_type, search, point_id))

    existing = await existing_point_ids(
        qdrant_client=qdrant_client,
        collection=WDK_SEARCHES_V1,
        ids=[point_id for _, _, point_id in candidates],
    )
    if existing:
        kept = [entry for entry in candidates if entry[2] not in existing]
        dropped = len(candidates) - len(kept)
        candidates = kept
        if dropped:
            logger.info("WDK searches skipped", siteId=site_id, skipped=dropped)

    return [(record_type, search) for record_type, search, _ in candidates]
async def _upsert_docs_batch(
    store: QdrantStore,
    embedder: OpenAIEmbeddings,
    collection: str,
    docs: JSONArray,
) -> None:
    """Embed and upsert every doc in ``docs`` that carries a string ``text``.

    Docs that are not dicts, or whose ``text`` is not a string, are skipped.
    For each remaining doc, its ``text``, its ``id`` and its ``payload`` are
    forwarded to ``embed_and_upsert``. A missing or non-dict payload becomes
    ``{}``. No call is made when no doc survives this filter.

    Args:
        store: Vector store the points are written to.
        embedder: Embedding client used to vectorise the texts.
        collection: Target collection name.
        docs: Candidate docs; may be empty.
    """
    if not docs:
        return
    ids: list[str | JSONValue] = []
    texts: list[str] = []
    payloads: list[JSONObject] = []
    for d in docs:
        if not isinstance(d, dict):
            continue
        text_value = d.get("text")
        if not isinstance(text_value, str):
            continue
        ids.append(d.get("id"))
        texts.append(text_value)
        payload_raw = d.get("payload")
        payloads.append(payload_raw if isinstance(payload_raw, dict) else {})
    # Guard on the filtered texts, not just on ``docs``. Otherwise a non-empty
    # batch of text-less docs would trigger an embed/upsert call with empty
    # inputs: a wasted and possibly rejected round-trip.
    if not texts:
        return
    await embed_and_upsert(
        store=store,
        embedder=embedder,
        collection=collection,
        ids=ids,
        texts=texts,
        payloads=payloads,
    )
[docs]
async def upsert_record_type_docs(
    store: QdrantStore,
    embedder: OpenAIEmbeddings,
    record_type_docs: JSONArray,
) -> None:
    """Embed and upsert record-type docs into ``WDK_RECORD_TYPES_V1``.

    This is a thin wrapper around ``_upsert_docs_batch`` with the collection
    fixed.
    """
    await _upsert_docs_batch(
        store=store,
        embedder=embedder,
        collection=WDK_RECORD_TYPES_V1,
        docs=record_type_docs,
    )
[docs]
async def upsert_search_docs_batch(
    store: QdrantStore,
    embedder: OpenAIEmbeddings,
    buffered: JSONArray,
) -> None:
    """Embed and upsert a buffered batch of search docs into ``WDK_SEARCHES_V1``.

    This is a thin wrapper around ``_upsert_docs_batch`` with the collection
    fixed.
    """
    await _upsert_docs_batch(
        store=store,
        embedder=embedder,
        collection=WDK_SEARCHES_V1,
        docs=buffered,
    )
[docs]
async def run_search_indexing_pipeline(
    *,
    searches_to_fetch: list[tuple[str, JSONObject]],
    make_doc: Callable[[str, JSONObject], Awaitable[tuple[JSONObject | None, bool]]],
    store: QdrantStore,
    embedder: OpenAIEmbeddings,
    concurrency: int,
    batch_size: int,
    site_id: str,
) -> None:
    """Build search docs concurrently and upsert them in batches.

    Each ``(record_type, search)`` pair is turned into a doc by ``make_doc``,
    which returns ``(doc_or_None, had_error)``. The docs are handed to
    ``run_concurrent_pipeline``, which flushes batches of up to ``batch_size``
    through ``upsert_search_docs_batch``. After the pipeline completes, the
    number of ``make_doc`` calls that reported an error is logged as a warning
    (only when non-zero).
    """
    detail_failures: int = 0

    async def build_doc(item: tuple[str, JSONObject]) -> JSONObject | None:
        nonlocal detail_failures
        record_type, search = item
        doc, had_error = await make_doc(record_type, search)
        if had_error:
            # There is no await between the read and the write, so the
            # increment is not interleaved by other coroutines on the loop.
            detail_failures += 1
        return doc

    async def write_batch(batch: list[JSONObject]) -> None:
        # Pass a shallow copy, presumably so later reuse of the pipeline's
        # buffer cannot affect this upsert. Confirm against the pipeline.
        await upsert_search_docs_batch(store, embedder, list(batch))

    await run_concurrent_pipeline(
        items=searches_to_fetch,
        process_fn=build_doc,
        flush_fn=write_batch,
        concurrency=concurrency,
        batch_size=batch_size,
    )

    if detail_failures:
        logger.warning(
            "WDK search details failed", siteId=site_id, failed=detail_failures
        )