"""Run delegation graph nodes respecting dependency ordering."""
import asyncio
from collections.abc import Awaitable, Callable
from veupath_chatbot.platform.logging import get_logger
from veupath_chatbot.platform.tool_errors import tool_error
from veupath_chatbot.platform.types import (
JSONArray,
JSONObject,
as_json_array,
as_json_object,
)
# Module-level logger; used to report nodes that could never be scheduled.
logger = get_logger(__name__)
[docs]
async def run_nodes_with_dependencies(
    *,
    nodes_by_id: dict[str, JSONObject],
    dependents: dict[str, list[str]],
    max_concurrency: int,
    run_node: Callable[[str, JSONObject, str | None], Awaitable[JSONObject]],
    format_dependency_context: Callable[..., str | None],
    results_by_id: dict[str, JSONObject] | None = None,
) -> tuple[JSONArray, dict[str, JSONObject]]:
    """Execute nodes concurrently while honoring their depends_on edges.

    A node becomes ready once every id in its ``depends_on`` list that names a
    known node has finished (non-string or unknown ids are ignored). At most
    ``max_concurrency`` nodes run at once (values below 1 are treated as 1),
    and each node is run at most once.

    :param nodes_by_id: Node payloads keyed by node id.
    :param dependents: Reverse edges: node id -> ids of nodes waiting on it.
        Entries naming ids absent from ``nodes_by_id`` are ignored.
    :param max_concurrency: Maximum number of nodes running at the same time.
    :param run_node: Called as ``run_node(node_id, node, dependency_context)``;
        its awaited value is the node's result.
    :param format_dependency_context: Called with keyword arguments
        ``task_id``, ``tasks_by_id`` and ``results_by_id`` right before a node
        is started; its return value is passed to ``run_node``.
    :param results_by_id: Optional mapping that dict results are recorded into
        (mutated in place); a new dict is created when omitted.
    :returns: ``(results, results_by_id)``. ``results`` holds node results in
        the order their tasks were observed to finish, followed by one
        ``CIRCULAR_DEPENDENCY`` error (tagged with ``id``) per node that was
        never scheduled, sorted by node id.
    :raises Exception: Whatever ``run_node`` or ``format_dependency_context``
        raises; nodes still running at that point are cancelled and awaited
        before the exception propagates.
    """
    node_ids = set(nodes_by_id)

    # Outstanding prerequisites per node. Only string ids naming a known node
    # count; anything else in ``depends_on`` is ignored.
    remaining_deps: dict[str, set[str]] = {}
    for node_id, node in nodes_by_id.items():
        depends_on_value = node.get("depends_on")
        if isinstance(depends_on_value, list):
            deps_list = as_json_array(depends_on_value)
            remaining_deps[node_id] = {
                dep for dep in deps_list if isinstance(dep, str) and dep in node_ids
            }
        else:
            remaining_deps[node_id] = set()

    limit = max(1, int(max_concurrency))
    # Every node ever placed on the ready queue. Checked before re-queueing so
    # a node runs at most once even if ``dependents`` repeats a child or holds
    # an edge that the child's own ``depends_on`` does not list.
    queued: set[str] = set()
    ready: list[str] = []
    for node_id, deps in remaining_deps.items():
        if not deps:
            ready.append(node_id)
            queued.add(node_id)

    running: dict[asyncio.Task[JSONObject], str] = {}
    results: JSONArray = []
    if results_by_id is None:
        results_by_id = {}

    async def run_one(
        node_id: str, node: JSONObject, dependency_context: str | None
    ) -> JSONObject:
        # ``run_node`` may return any Awaitable; create_task needs a coroutine.
        return await run_node(node_id, node, dependency_context)

    try:
        while ready or running:
            # Fill free slots; ``ready`` is consumed LIFO.
            while ready and len(running) < limit:
                node_id = ready.pop()
                node = nodes_by_id[node_id]
                dependency_context = format_dependency_context(
                    task_id=node_id,
                    tasks_by_id=nodes_by_id,
                    results_by_id=results_by_id,
                )
                task = asyncio.create_task(
                    run_one(node_id, node, dependency_context)
                )
                running[task] = node_id

            # ``running`` is non-empty here: either it already was, or ``ready``
            # was non-empty and the loop above started at least one task.
            done, _ = await asyncio.wait(
                running.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for finished in done:
                finished_id = running.pop(finished)
                result = finished.result()  # re-raises run_node's exception
                results.append(result)
                if isinstance(result, dict):
                    results_by_id[finished_id] = result
                for child in dependents.get(finished_id, []):
                    child_deps = remaining_deps.get(child)
                    # Skip unknown ids and nodes that were already queued.
                    if child_deps is None or child in queued:
                        continue
                    child_deps.discard(finished_id)
                    if not child_deps:
                        ready.append(child)
                        queued.add(child)
    finally:
        # Non-empty only when an exception escaped the loop or this coroutine
        # was cancelled: cancel the stragglers and wait for them so no orphaned
        # tasks keep running (and their exceptions are retrieved).
        if running:
            for task in running:
                task.cancel()
            await asyncio.gather(*running, return_exceptions=True)

    # Nodes never queued sit on or behind a dependency cycle (or their release
    # edge is missing from ``dependents``).
    unscheduled = sorted(node_ids - queued)
    if unscheduled:
        logger.error(
            "Circular dependency detected — nodes never scheduled",
            unscheduled_nodes=unscheduled,
        )
        for node_id in unscheduled:
            error_payload = tool_error(
                "CIRCULAR_DEPENDENCY",
                f"Task '{node_id}' has circular dependencies and was skipped.",
            )
            # Tag the error with its node so callers can correlate it.
            error_payload["id"] = node_id
            results.append(error_payload)

    return results, results_by_id
[docs]
def partition_task_results(
    results: JSONArray,
) -> tuple[JSONArray, JSONArray]:
    """Normalize task results and collect the ones that produced no steps.

    Non-dict entries are dropped. Every dict entry yields exactly one
    ``validated`` record with the keys ``id``, ``task``, ``steps`` and
    ``notes``. When an entry has no steps (missing, not a list, or empty), its
    record gets an empty ``steps`` list and a ``NO_STEPS_CREATED`` error tagged
    with the entry's ``id`` and ``task`` is also appended to ``rejected``.

    :param results: Results array.
    :returns: ``(validated, rejected)``.
    """
    validated: JSONArray = []
    rejected: JSONArray = []
    for raw in results:
        if isinstance(raw, dict):
            entry = as_json_object(raw)
            raw_steps = entry.get("steps")
            step_list: JSONArray = []
            if isinstance(raw_steps, list):
                step_list = as_json_array(raw_steps)
            entry_id = entry.get("id")
            entry_task = entry.get("task")
            if len(step_list) == 0:
                error_payload = tool_error(
                    "NO_STEPS_CREATED",
                    "No steps created for the subtask.",
                )
                error_payload.update({"id": entry_id, "task": entry_task})
                rejected.append(error_payload)
                # Fresh list so the record never aliases the caller's input.
                step_list = []
            validated.append(
                {
                    "id": entry_id,
                    "task": entry_task,
                    "steps": step_list,
                    "notes": entry.get("notes"),
                }
            )
    return validated, rejected