"""Grouped runtime orchestration for authored flows."""
from __future__ import annotations
from queue import Queue
import threading
from typing import TYPE_CHECKING, Callable
from data_engine.core.primitives import FlowContext
from data_engine.runtime.execution.single import FlowRuntime, RuntimeCacheLedgerService, default_runtime_cache_ledger_service
from data_engine.runtime.runtime_db import RuntimeCacheLedger
from data_engine.runtime.stop import RuntimeStopController
if TYPE_CHECKING:
from data_engine.core.flow import Flow
class GroupedFlowRuntime:
    """Grouped orchestrator: sequential within a group, parallel across groups.

    Flows sharing the same ``group`` name run sequentially inside a single
    :class:`FlowRuntime`; each distinct group runs on its own daemon thread.
    """

    def __init__(
        self,
        flows: tuple["Flow", ...],
        *,
        continuous: bool,
        runtime_stop_event: threading.Event | None = None,
        flow_stop_event: threading.Event | None = None,
        status_callback: Callable[[str], None] | None = None,
        runtime_ledger: RuntimeCacheLedger | None = None,
        runtime_ledger_service: RuntimeCacheLedgerService | None = None,
        runtime_ledger_factory: Callable[[], RuntimeCacheLedger] | None = None,
        run_stop_controller: RuntimeStopController | None = None,
    ) -> None:
        """Store run configuration.

        Args:
            flows: Flows to execute; ``flow.group`` decides grouping.
            continuous: Whether flows keep running until stopped.
            runtime_stop_event: External event that stops the whole run.
            flow_stop_event: External event that stops individual flows.
            status_callback: Receives human-readable status strings.
            runtime_ledger: Existing ledger to reuse. When omitted, one is
                created here and closed by :meth:`run` on exit.
            runtime_ledger_service: Service used to open runtime ledgers.
            runtime_ledger_factory: Factory for runtime ledgers; defaults to
                the service's ``open_runtime_cache_ledger``.
            run_stop_controller: Shared stop controller for the run.
        """
        self.flows = tuple(flows)
        self.continuous = continuous
        self.runtime_stop_event = runtime_stop_event
        self.flow_stop_event = flow_stop_event
        self.run_stop_controller = run_stop_controller or RuntimeStopController()
        self.status_callback = status_callback
        self._runtime_ledger_service = runtime_ledger_service or default_runtime_cache_ledger_service()
        self._runtime_ledger_factory = runtime_ledger_factory or self._runtime_ledger_service.open_runtime_cache_ledger
        # We close the ledger ourselves only if we created it.
        self._owns_runtime_ledger = runtime_ledger is None
        self.runtime_ledger = runtime_ledger or self._runtime_ledger_factory()

    def run(self) -> list[FlowContext]:
        """Execute all flows and return their contexts in group order.

        Raises:
            Exception: Re-raises the first per-group error when
                ``continuous`` is false.
        """
        grouped = self._grouped_flows()
        try:
            if len(grouped) <= 1:
                return self._run_single_group(grouped)
            return self._run_parallel_groups(grouped)
        finally:
            # Bug fix: previously an internally-created ledger was closed only
            # on the multi-group path, so single-group runs leaked it.
            if self._owns_runtime_ledger:
                self.runtime_ledger.close()

    def _run_single_group(self, grouped: dict[str, tuple["Flow", ...]]) -> list[FlowContext]:
        """Run the only (possibly empty) group directly on this thread."""
        only = next(iter(grouped.values()), ())
        return FlowRuntime(
            tuple(only),
            continuous=self.continuous,
            runtime_stop_event=self.runtime_stop_event,
            flow_stop_event=self.flow_stop_event,
            status_callback=self.status_callback,
            runtime_ledger=self.runtime_ledger,
            runtime_ledger_service=self._runtime_ledger_service,
            # Consistency fix: the parallel path already forwards the factory.
            runtime_ledger_factory=self._runtime_ledger_factory,
            run_stop_controller=self.run_stop_controller,
        ).run()

    def _run_parallel_groups(self, grouped: dict[str, tuple["Flow", ...]]) -> list[FlowContext]:
        """Run each group on its own daemon thread and merge the results."""
        results_by_group: dict[str, list[FlowContext]] = {name: [] for name in grouped}
        errors: Queue[tuple[str, Exception]] = Queue()
        # Share one stop-event pair so a failing group can halt the others.
        internal_runtime_stop = self.runtime_stop_event or threading.Event()
        internal_flow_stop = self.flow_stop_event or threading.Event()

        def run_group(group_name: str, group_flows: tuple["Flow", ...]) -> None:
            try:
                runtime = FlowRuntime(
                    group_flows,
                    continuous=self.continuous,
                    runtime_stop_event=internal_runtime_stop,
                    flow_stop_event=internal_flow_stop,
                    status_callback=self.status_callback,
                    # NOTE(review): one ledger is shared across group threads;
                    # assumes it tolerates concurrent use — confirm.
                    runtime_ledger=self.runtime_ledger,
                    runtime_ledger_service=self._runtime_ledger_service,
                    runtime_ledger_factory=self._runtime_ledger_factory,
                    run_stop_controller=self.run_stop_controller,
                )
                results_by_group[group_name] = runtime.run()
            except Exception as exc:  # pragma: no cover
                errors.put((group_name, exc))
                if not self.continuous:
                    # Fail fast: ask the remaining groups to stop.
                    internal_runtime_stop.set()

        threads: list[threading.Thread] = []
        for group_name, group_flows in grouped.items():
            thread = threading.Thread(target=run_group, args=(group_name, group_flows), daemon=True)
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()
        if not self.continuous and not errors.empty():
            _, exc = errors.get()
            raise exc
        # Preserve authoring order: groups in insertion order, then each
        # group's contexts in execution order.
        ordered_results: list[FlowContext] = []
        for group_name in grouped:
            ordered_results.extend(results_by_group[group_name])
        return ordered_results

    def _grouped_flows(self) -> dict[str, tuple["Flow", ...]]:
        """Bucket flows by ``flow.group``.

        A flow without a group gets a synthetic ``group-<index>`` key, so
        ungrouped flows each run in isolation rather than being merged into
        one shared bucket.
        """
        grouped: dict[str, list["Flow"]] = {}
        for index, flow in enumerate(self.flows):
            key = flow.group or f"group-{index}"
            grouped.setdefault(key, []).append(flow)
        return {name: tuple(group_flows) for name, group_flows in grouped.items()}
# Explicit public API for ``from ... import *``.
__all__ = ["GroupedFlowRuntime"]