Source code for veupath_chatbot.integrations.vectorstore.ingest.pipeline

"""Generic concurrent pipeline utility for ingest workers.

Provides a reusable asyncio-based producer/worker/consumer pattern with:
- A pool of concurrent workers that process items via ``process_fn``
- A single consumer that batches results and flushes via ``flush_fn``
- Sentinel-based shutdown with proper queue joining
"""

import asyncio
from collections.abc import Awaitable, Callable, Sequence


[docs] async def run_concurrent_pipeline[InT, OutT]( *, items: Sequence[InT], process_fn: Callable[[InT], Awaitable[OutT | None]], flush_fn: Callable[[list[OutT]], Awaitable[None]], concurrency: int, batch_size: int, on_error: Callable[[InT, Exception], None] | None = None, ) -> None: """Run *items* through a concurrent worker pool, batching results to *flush_fn*. Parameters ---------- items: The input items to process. process_fn: Async callable that transforms a single input item into an output. Return ``None`` to skip the item (it will not be flushed). flush_fn: Async callable that receives a batch of outputs to persist/upsert. concurrency: Maximum number of concurrent worker tasks. batch_size: Number of outputs to accumulate before calling *flush_fn*. on_error: Optional sync callback invoked when *process_fn* raises. If not provided, errors are silently swallowed and the item is skipped. """ if not items: return conc = max(1, int(concurrency)) bs = max(1, int(batch_size)) jobs: asyncio.Queue[InT | None] = asyncio.Queue() results: asyncio.Queue[OutT | None] = asyncio.Queue() async def worker() -> None: while True: item = await jobs.get() if item is None: jobs.task_done() break try: result = await process_fn(item) if result is not None: await results.put(result) except Exception as exc: if on_error is not None: on_error(item, exc) finally: jobs.task_done() async def consumer() -> None: buffered: list[OutT] = [] while True: item = await results.get() if item is None: results.task_done() break buffered.append(item) results.task_done() if len(buffered) >= bs: await flush_fn(buffered) buffered = [] if buffered: await flush_fn(buffered) workers = [asyncio.create_task(worker()) for _ in range(conc)] consumer_task = asyncio.create_task(consumer()) for item in items: await jobs.put(item) for _ in workers: await jobs.put(None) await jobs.join() await results.put(None) await consumer_task for w in workers: await w