Selaa lähdekoodia

feat: add advanced workflow controls

Jax Docker 1 kuukausi sitten
vanhempi
sitoutus
51ebd18f4e

+ 48 - 0
README.md

@@ -561,6 +561,12 @@ Current behavior:
 - `knowledge-retrieval` / `retriever` nodes run keyword retrieval over inline or HTTP JSON documents
 - `tool` nodes persist resolved binding/tool metadata to `output_json`
 - default executors persist basic executor metadata to `output_json`
+- parallel fan-out is supported by defining multiple outgoing edges from one node
+- join nodes wait for predecessor completion according to `config.join_policy` (`all_completed`, the default for `join` nodes, or `any_completed`)
+- loop/re-entry is supported with `config.allow_loop=true` and `config.max_iterations`
+- retry is supported with `config.retry_policy.max_attempts` and `config.retry_policy.retry_delay_seconds`
+- delayed scheduling and node timeout use `config.delay_seconds` and `config.timeout_seconds`
+- compensation nodes can be queued on failure with `config.compensation_node_id`
 
 Runtime template context:
 
@@ -605,6 +611,48 @@ Conditional edge example:
 ]
 ```
 
+Join node config example:
+
+```json
+{
+  "id": "join-results",
+  "type": "join",
+  "config": {
+    "join_policy": "all_completed"
+  }
+}
+```
+
+Loop and retry config example:
+
+```json
+{
+  "id": "poll-status",
+  "type": "tool",
+  "config": {
+    "allow_loop": true,
+    "max_iterations": 5,
+    "timeout_seconds": 30,
+    "retry_policy": {
+      "max_attempts": 3,
+      "retry_delay_seconds": 2
+    }
+  }
+}
+```
+
+Compensation config example:
+
+```json
+{
+  "id": "charge-card",
+  "type": "tool",
+  "config": {
+    "compensation_node_id": "refund-card"
+  }
+}
+```
+
 Template node config example:
 
 ```json

+ 2 - 0
libs/core-domain/src/core_domain/runtime_contracts.py

@@ -58,6 +58,8 @@ class NodeRunContract(BaseModel):
     status: NodeRunStatus
     output_text: str | None = None
     output_json: dict[str, JSONValue] | None = None
+    scheduled_time: datetime | None = None
+    timeout_time: datetime | None = None
     queued_time: datetime | None = None
     created_time: datetime
 

+ 31 - 0
services/runtime-service/alembic/versions/20260425_0007_add_node_run_scheduling.py

@@ -0,0 +1,31 @@
+"""add node run scheduling fields
+
+Revision ID: 20260425_0007
+Revises: 20260423_0006
+Create Date: 2026-04-25 16:20:00
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "20260425_0007"
+down_revision: str | None = "20260423_0006"
+branch_labels: Sequence[str] | None = None
+depends_on: Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Add the nullable scheduling columns to ``node_run`` and index each one."""
+    indexed_columns = (
+        ("ix_node_run_scheduled_time", "scheduled_time"),
+        ("ix_node_run_timeout_time", "timeout_time"),
+    )
+    # Add both columns first, then create both indexes, in declaration order.
+    for _, column_name in indexed_columns:
+        op.add_column("node_run", sa.Column(column_name, sa.DateTime(), nullable=True))
+    for index_name, column_name in indexed_columns:
+        op.create_index(index_name, "node_run", [column_name], unique=False)
+
+
+def downgrade() -> None:
+    """Drop the scheduling indexes, then the scheduling columns, from ``node_run``."""
+    indexed_columns = (
+        ("ix_node_run_timeout_time", "timeout_time"),
+        ("ix_node_run_scheduled_time", "scheduled_time"),
+    )
+    # Indexes go first so no index references a column that is already gone.
+    for index_name, _ in indexed_columns:
+        op.drop_index(index_name, table_name="node_run")
+    for _, column_name in indexed_columns:
+        op.drop_column("node_run", column_name)

+ 379 - 11
services/runtime-service/app/application/services.py

@@ -2,11 +2,13 @@ from datetime import datetime, timedelta
 
 from sqlalchemy.orm import Session
 
+from core_dsl import parse_workflow_definition
 from core_domain import (
     InitialNodeContract,
     NodeExecutionContextContract,
     NodeExecutionResultContract,
     NodeRunStatus,
+    WorkflowVersionContract,
     WorkflowRunStatus,
 )
 
@@ -18,10 +20,17 @@ from app.domain.repositories import (
     TraceSpanRepository,
     WorkflowRunRepository,
 )
-from app.infrastructure.executors import NodeExecutionDispatcher, build_node_execution_dispatcher_with_clients
+from app.infrastructure.executors import (
+    NodeExecutionDispatcher,
+    build_node_execution_dispatcher_with_clients,
+)
 from app.infrastructure.code_runner_client import CodeRunnerClient
 from app.infrastructure.model_gateway_client import ModelGatewayClient
-from app.infrastructure.planner import derive_initial_node, derive_node_config, derive_successor_nodes
+from app.infrastructure.planner import (
+    derive_initial_node,
+    derive_node_config,
+    derive_successor_nodes,
+)
 from app.infrastructure.tool_client import ToolServiceClient
 from app.infrastructure.workflow_client import WorkflowServiceClient
 from app.bootstrap.settings import RuntimeServiceSettings
@@ -76,12 +85,20 @@ class RuntimeApplicationService:
                 run_id=workflow_run.id,
                 current_node_count=1,
             )
+            initial_config = self._resolve_node_config(
+                tenant_id=payload.tenant_id,
+                workflow_version_id=payload.workflow_version_id,
+                node_id=initial_node.node_id,
+            )
+            scheduled_time, timeout_time = self._build_node_timing(initial_config)
             node_run = self.node_run_repository.create(
                 tenant_id=payload.tenant_id,
                 run_id=workflow_run.id,
                 node_id=initial_node.node_id,
                 node_type=initial_node.node_type,
                 status=initial_node.status,
+                scheduled_time=scheduled_time,
+                timeout_time=timeout_time,
             )
             self._log_event(
                 tenant_id=payload.tenant_id,
@@ -112,7 +129,10 @@ class RuntimeApplicationService:
         return workflow_run, node_run
 
     def list_runs(self, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
-        return self.workflow_run_repository.list_by_scope(tenant_id=tenant_id, session_id=session_id)
+        return self.workflow_run_repository.list_by_scope(
+            tenant_id=tenant_id,
+            session_id=session_id,
+        )
 
     def list_node_runs(self, tenant_id: str, run_id: str) -> list[NodeRun]:
         return self.node_run_repository.list_by_run(tenant_id=tenant_id, run_id=run_id)
@@ -202,6 +222,15 @@ class RuntimeApplicationService:
 
         if payload.status == "completed":
             self._schedule_successor_nodes(node_run)
+        if payload.status == "failed":
+            workflow_run = self.workflow_run_repository.get_by_id(node_run.run_id)
+            if workflow_run is not None:
+                node_config = self._resolve_node_config(
+                    tenant_id=node_run.tenant_id,
+                    workflow_version_id=workflow_run.workflow_version_id,
+                    node_id=node_run.node_id,
+                )
+                self._schedule_compensation_node(node_run=node_run, node_config=node_config)
 
         self._sync_workflow_run_status_from_nodes(
             tenant_id=node_run.tenant_id,
@@ -223,9 +252,38 @@ class RuntimeApplicationService:
             return None
 
         if node_run.status in {"completed", "failed", "skipped"}:
-            executor_name = self.execution_dispatcher.resolve_executor(node_run.node_type).executor_name
+            executor_name = self.execution_dispatcher.resolve_executor(
+                node_run.node_type
+            ).executor_name
             return workflow_run, node_run, executor_name
 
+        node_config = self._resolve_node_config(
+            tenant_id=node_run.tenant_id,
+            workflow_version_id=workflow_run.workflow_version_id,
+            node_id=node_run.node_id,
+        )
+        if self._node_has_timed_out(node_run):
+            timed_out_node_run = self.update_node_run_status(
+                node_run_id=node_run.id,
+                payload=NodeRunStatusUpdateRequest(
+                    status="failed",
+                    worker_key=payload.worker_key,
+                    error_code="node_timeout",
+                    error_message=f"node timed out: {node_run.node_id}",
+                    output_json={
+                        "timeout_time": node_run.timeout_time.isoformat()
+                        if node_run.timeout_time is not None
+                        else None,
+                    },
+                ),
+            )
+            if timed_out_node_run is None:
+                return None
+            executor_name = self.execution_dispatcher.resolve_executor(
+                node_run.node_type
+            ).executor_name
+            return workflow_run, timed_out_node_run, executor_name
+
         running_node_run = self.node_run_repository.update_status(
             node_run_id=node_run_id,
             status="running",
@@ -251,6 +309,7 @@ class RuntimeApplicationService:
             workflow_run=workflow_run,
             node_run=running_node_run,
             worker_key=payload.worker_key,
+            node_config_json=node_config,
         )
         executor_name = self.execution_dispatcher.resolve_executor(
             running_node_run.node_type
@@ -271,7 +330,10 @@ class RuntimeApplicationService:
         )
 
         try:
-            result, executor_name = self.execution_dispatcher.execute(context=context, request=payload)
+            result, executor_name = self.execution_dispatcher.execute(
+                context=context,
+                request=payload,
+            )
         except Exception as exc:
             result = NodeExecutionResultContract(
                 status="failed",
@@ -280,6 +342,60 @@ class RuntimeApplicationService:
                 error_message=str(exc),
             )
 
+        if result.status == "failed" and self._should_retry_node(
+            node_run=running_node_run,
+            node_config_json=context.node_config_json,
+        ):
+            retry_time, retry_timeout_time = self._build_retry_timing(context.node_config_json)
+            retried_node_run = self.node_run_repository.requeue_for_retry(
+                node_run_id=running_node_run.id,
+                scheduled_time=retry_time,
+                timeout_time=retry_timeout_time,
+                error_code=result.error_code,
+                error_message=result.error_message,
+                output_text=result.output_text,
+                output_json={
+                    **(result.output_json or {}),
+                    "retry_scheduled_time": retry_time.isoformat(),
+                    "retry_reason": result.error_code or "node_failed",
+                },
+            )
+            if retried_node_run is None:
+                return None
+            self.trace_span_repository.finish(
+                span_id=trace_span.id,
+                status="error",
+                error_code=result.error_code,
+                error_message=result.error_message,
+                attributes_json={
+                    "node_status": "queued",
+                    "executor_name": executor_name,
+                    "retry_scheduled": True,
+                    "attempt_no": retried_node_run.attempt_no,
+                },
+            )
+            self._log_event(
+                tenant_id=retried_node_run.tenant_id,
+                run_id=retried_node_run.run_id,
+                node_run_id=retried_node_run.id,
+                event_type="node_retry_scheduled",
+                message=f"node retry scheduled: {retried_node_run.node_id}",
+                detail_json={
+                    "node_id": retried_node_run.node_id,
+                    "attempt_no": retried_node_run.attempt_no,
+                    "scheduled_time": retry_time.isoformat(),
+                    "error_code": result.error_code,
+                },
+            )
+            self._sync_workflow_run_status_from_nodes(
+                tenant_id=retried_node_run.tenant_id,
+                run_id=retried_node_run.run_id,
+            )
+            workflow_run = self.workflow_run_repository.get_by_id(retried_node_run.run_id)
+            if workflow_run is None:
+                return None
+            return workflow_run, retried_node_run, executor_name
+
         final_node_run = self.update_node_run_status(
             node_run_id=running_node_run.id,
             payload=NodeRunStatusUpdateRequest(
@@ -477,17 +593,52 @@ class RuntimeApplicationService:
             run_id=node_run.run_id,
             node_ids=[item.node_id for item in successor_nodes],
         )
-        existing_node_ids = {item.node_id for item in existing_nodes}
+        existing_node_counts: dict[str, int] = {}
+        for item in existing_nodes:
+            existing_node_counts[item.node_id] = existing_node_counts.get(item.node_id, 0) + 1
 
         for successor in successor_nodes:
-            if successor.node_id in existing_node_ids:
+            successor_config = derive_node_config(workflow_version, successor.node_id)
+            if not self._is_join_ready(
+                workflow_version=workflow_version,
+                run_node_runs=self.node_run_repository.list_by_run(
+                    tenant_id=node_run.tenant_id,
+                    run_id=node_run.run_id,
+                ),
+                successor_node_id=successor.node_id,
+                successor_node_type=successor.node_type,
+                successor_config=successor_config,
+            ):
+                self._log_event(
+                    tenant_id=node_run.tenant_id,
+                    run_id=node_run.run_id,
+                    node_run_id=None,
+                    event_type="join_waiting",
+                    message=f"join node waiting for predecessors: {successor.node_id}",
+                    detail_json={
+                        "node_id": successor.node_id,
+                        "source_node_id": node_run.node_id,
+                    },
+                )
+                continue
+            if not self._can_schedule_repeated_node(
+                successor_config,
+                existing_count=existing_node_counts.get(successor.node_id, 0),
+            ):
                 continue
+            scheduled_time, timeout_time = self._build_node_timing(successor_config)
             self.node_run_repository.create(
                 tenant_id=node_run.tenant_id,
                 run_id=node_run.run_id,
+                parent_node_run_id=node_run.id,
                 node_id=successor.node_id,
                 node_type=successor.node_type,
                 status=successor.status,
+                scheduled_time=scheduled_time,
+                timeout_time=timeout_time,
+            )
+            existing_node_counts[successor.node_id] = (
+                existing_node_counts.get(successor.node_id, 0) + 1
             )
             self._log_event(
                 tenant_id=node_run.tenant_id,
@@ -509,6 +660,7 @@ class RuntimeApplicationService:
         workflow_run: WorkflowRun,
         node_run: NodeRun,
         worker_key: str | None,
+        node_config_json: dict[str, JSONValue] | None = None,
     ) -> NodeExecutionContextContract:
         run_state_json, node_output_json_by_node_id, node_output_text_by_node_id = (
             self._build_run_state_maps(
@@ -522,7 +674,9 @@ class RuntimeApplicationService:
             node_run_id=node_run.id,
             node_id=node_run.node_id,
             node_type=node_run.node_type,
-            node_config_json=self._resolve_node_config(
+            node_config_json=node_config_json
+            if node_config_json is not None
+            else self._resolve_node_config(
                 tenant_id=node_run.tenant_id,
                 workflow_version_id=workflow_run.workflow_version_id,
                 node_id=node_run.node_id,
@@ -562,6 +716,219 @@ class RuntimeApplicationService:
 
         return run_state_json, node_output_json_by_node_id, node_output_text_by_node_id
 
+    def _build_node_timing(
+        self,
+        node_config_json: dict[str, JSONValue],
+    ) -> tuple[datetime, datetime | None]:
+        """Compute when a node becomes runnable and when it times out.
+
+        Reads ``delay_seconds`` and ``timeout_seconds`` from the node config
+        (both default to 0). A negative delay is treated as 0.
+
+        Returns:
+            ``(scheduled_time, timeout_time)``. ``timeout_time`` is measured
+            from ``scheduled_time`` and is ``None`` unless ``timeout_seconds``
+            is positive.
+        """
+        current_time = datetime.utcnow()
+        delay = max(self._read_int_value(node_config_json, "delay_seconds", default=0), 0)
+        timeout = self._read_int_value(node_config_json, "timeout_seconds", default=0)
+        scheduled_time = current_time + timedelta(seconds=delay)
+        if timeout <= 0:
+            return scheduled_time, None
+        return scheduled_time, scheduled_time + timedelta(seconds=timeout)
+
+    def _node_has_timed_out(self, node_run: NodeRun) -> bool:
+        """Return True when the node run has a timeout deadline that is not in the future."""
+        deadline = node_run.timeout_time
+        if deadline is None:
+            return False
+        return deadline <= datetime.utcnow()
+
+    def _should_retry_node(
+        self,
+        *,
+        node_run: NodeRun,
+        node_config_json: dict[str, JSONValue],
+    ) -> bool:
+        """Return True while the retry policy still allows another attempt.
+
+        ``retry_policy.max_attempts`` defaults to 1 when missing or not an
+        integer; a retry is allowed only while the node run's ``attempt_no``
+        is below that limit.
+        """
+        policy = self._read_dict_value(node_config_json, "retry_policy")
+        attempt_limit = self._read_int_value(policy, "max_attempts", default=1)
+        return node_run.attempt_no < attempt_limit
+
+    def _read_retry_delay_seconds(self, node_config_json: dict[str, JSONValue]) -> int:
+        """Return ``retry_policy.retry_delay_seconds`` from the node config, or 0 if unset."""
+        return self._read_int_value(
+            self._read_dict_value(node_config_json, "retry_policy"),
+            "retry_delay_seconds",
+            default=0,
+        )
+
+    def _build_retry_timing(
+        self,
+        node_config_json: dict[str, JSONValue],
+    ) -> tuple[datetime, datetime | None]:
+        """Compute the schedule and timeout deadline for a retry attempt.
+
+        Reads ``retry_policy.retry_delay_seconds`` and ``timeout_seconds`` from
+        the node config (both default to 0). A negative retry delay is treated
+        as 0.
+
+        Returns:
+            ``(retry_time, timeout_time)``. ``timeout_time`` is measured from
+            ``retry_time`` and is ``None`` unless ``timeout_seconds`` is
+            positive.
+        """
+        # Clamp negative delays to 0, as _build_node_timing does for
+        # delay_seconds. Otherwise the retry and its timeout deadline would be
+        # placed in the past and the new attempt could time out immediately.
+        retry_delay_seconds = max(self._read_retry_delay_seconds(node_config_json), 0)
+        retry_time = datetime.utcnow() + timedelta(seconds=retry_delay_seconds)
+        timeout_seconds = self._read_int_value(node_config_json, "timeout_seconds", default=0)
+        timeout_time = (
+            retry_time + timedelta(seconds=timeout_seconds)
+            if timeout_seconds > 0
+            else None
+        )
+        return retry_time, timeout_time
+
+    def _is_join_ready(
+        self,
+        *,
+        workflow_version: WorkflowVersionContract,
+        run_node_runs: list[NodeRun],
+        successor_node_id: str,
+        successor_node_type: str,
+        successor_config: dict[str, JSONValue],
+    ) -> bool:
+        """Decide whether a successor node may be scheduled under its join policy.
+
+        A node is gated only if its type is ``"join"`` or it sets
+        ``join_policy``. A gated node is ready when its predecessors (sources
+        of edges targeting it) have a ``completed`` or ``skipped`` node run:
+        all of them for ``all_completed`` (also used when no policy is set),
+        at least one for ``any_completed``. Unrecognized policies, an
+        unparsable workflow, and nodes without predecessors are ready.
+        """
+        policy = self._read_string_value(successor_config, "join_policy")
+        is_gated = successor_node_type == "join" or policy is not None
+        if not is_gated:
+            return True
+
+        definition = self._parse_workflow(workflow_version)
+        if definition is None:
+            return True
+
+        required = {
+            edge.source for edge in definition.edges if edge.target == successor_node_id
+        }
+        if not required:
+            return True
+
+        finished = {
+            node_run.node_id
+            for node_run in run_node_runs
+            if node_run.status in {"completed", "skipped"}
+        }
+        if policy is None or policy == "all_completed":
+            return required <= finished
+        if policy == "any_completed":
+            return not required.isdisjoint(finished)
+        return True
+
+    def _can_schedule_repeated_node(
+        self,
+        node_config_json: dict[str, JSONValue],
+        *,
+        existing_count: int,
+    ) -> bool:
+        """Return whether another run of a node may be created.
+
+        The first run is always allowed. Further runs require
+        ``allow_loop`` to be true and ``existing_count`` to be below
+        ``max_iterations`` (default 1).
+        """
+        if existing_count == 0:
+            return True
+        if not self._read_bool_value(node_config_json, "allow_loop", default=False):
+            return False
+        iteration_cap = self._read_int_value(node_config_json, "max_iterations", default=1)
+        return existing_count < iteration_cap
+
+    def _schedule_compensation_node(
+        self,
+        *,
+        node_run: NodeRun,
+        node_config: dict[str, JSONValue],
+    ) -> None:
+        """Queue the compensation node configured for a node run, if any.
+
+        The compensation target is read from ``compensation_node_id`` or, as a
+        fallback, from ``compensation.node_id`` in ``node_config``. Nothing is
+        queued when no target is configured, when the workflow run cannot be
+        loaded, or when the target already has node runs in this run and its
+        own loop settings do not allow another one.
+
+        Args:
+            node_run: The node run to compensate for; it becomes the parent of
+                the created compensation node run.
+            node_config: Node config from which the compensation target is read.
+        """
+        compensation_node_id = self._read_string_value(node_config, "compensation_node_id")
+        if compensation_node_id is None:
+            # Fall back to the nested ``compensation.node_id`` form.
+            compensation_config = self._read_dict_value(node_config, "compensation")
+            compensation_node_id = self._read_string_value(compensation_config, "node_id")
+        if compensation_node_id is None:
+            return
+
+        workflow_run = self.workflow_run_repository.get_by_id(node_run.run_id)
+        if workflow_run is None:
+            return
+        # From here on ``compensation_config`` is the compensation node's own
+        # config; it drives the repeat check and the timing below.
+        compensation_config = self._resolve_node_config(
+            tenant_id=node_run.tenant_id,
+            workflow_version_id=workflow_run.workflow_version_id,
+            node_id=compensation_node_id,
+        )
+        existing_nodes = self.node_run_repository.list_by_run_and_node_ids(
+            tenant_id=node_run.tenant_id,
+            run_id=node_run.run_id,
+            node_ids=[compensation_node_id],
+        )
+        # Skip when the compensation node already ran in this run and its
+        # allow_loop / max_iterations settings do not permit another run.
+        if existing_nodes and not self._can_schedule_repeated_node(
+            compensation_config,
+            existing_count=len(existing_nodes),
+        ):
+            return
+        # Use "compensation" as the node type when it cannot be resolved from
+        # the workflow definition.
+        compensation_node_type = self._resolve_workflow_node_type(
+            tenant_id=node_run.tenant_id,
+            workflow_version_id=workflow_run.workflow_version_id,
+            node_id=compensation_node_id,
+        ) or "compensation"
+        scheduled_time, timeout_time = self._build_node_timing(compensation_config)
+        created = self.node_run_repository.create(
+            tenant_id=node_run.tenant_id,
+            run_id=node_run.run_id,
+            parent_node_run_id=node_run.id,
+            node_id=compensation_node_id,
+            node_type=compensation_node_type,
+            status="queued",
+            scheduled_time=scheduled_time,
+            timeout_time=timeout_time,
+        )
+        self._log_event(
+            tenant_id=node_run.tenant_id,
+            run_id=node_run.run_id,
+            node_run_id=created.id,
+            event_type="compensation_queued",
+            message=f"compensation node queued: {compensation_node_id}",
+            detail_json={
+                "failed_node_id": node_run.node_id,
+                "compensation_node_id": compensation_node_id,
+            },
+        )
+
+    def _parse_workflow(self, workflow_version: WorkflowVersionContract):
+        """Parse the workflow version's ``dsl_json`` with ``parse_workflow_definition``."""
+        dsl_json = workflow_version.dsl_json
+        return parse_workflow_definition(dsl_json)
+
+    def _resolve_workflow_node_type(
+        self,
+        *,
+        tenant_id: str,
+        workflow_version_id: str,
+        node_id: str,
+    ) -> str | None:
+        """Look up the type of ``node_id`` in the given workflow version.
+
+        Returns ``None`` when no workflow client is configured, when the
+        workflow definition parses to ``None``, or when no node matches.
+        """
+        client = self.workflow_client
+        if client is None:
+            return None
+        version = client.get_workflow_version(
+            tenant_id=tenant_id,
+            workflow_version_id=workflow_version_id,
+        )
+        definition = self._parse_workflow(version)
+        if definition is None:
+            return None
+        # First matching node wins, as in a linear scan.
+        return next((node.type for node in definition.nodes if node.id == node_id), None)
+
+    def _read_string_value(self, payload: dict[str, JSONValue], key: str) -> str | None:
+        """Return ``payload[key]`` if it is a non-empty string, otherwise ``None``."""
+        candidate = payload.get(key)
+        is_usable = isinstance(candidate, str) and candidate != ""
+        return candidate if is_usable else None
+
+    def _read_bool_value(
+        self,
+        payload: dict[str, JSONValue],
+        key: str,
+        *,
+        default: bool,
+    ) -> bool:
+        """Return ``payload[key]`` if it is a real ``bool``, otherwise ``default``."""
+        candidate = payload.get(key)
+        return candidate if isinstance(candidate, bool) else default
+
+    def _read_int_value(
+        self,
+        payload: dict[str, JSONValue],
+        key: str,
+        *,
+        default: int,
+    ) -> int:
+        """Return ``payload[key]`` if it is an ``int`` (but not a ``bool``), else ``default``."""
+        candidate = payload.get(key)
+        # bool is a subclass of int, so it has to be rejected explicitly.
+        if isinstance(candidate, bool) or not isinstance(candidate, int):
+            return default
+        return candidate
+
+    def _read_dict_value(
+        self,
+        payload: dict[str, JSONValue],
+        key: str,
+    ) -> dict[str, JSONValue]:
+        """Return a copy of ``payload[key]`` with string keys, or ``{}`` if it is not a dict."""
+        nested = payload.get(key)
+        if not isinstance(nested, dict):
+            return {}
+        return {str(name): entry for name, entry in nested.items()}
+
     def _sync_workflow_run_status_from_nodes(self, *, tenant_id: str, run_id: str) -> None:
         node_runs = self.node_run_repository.list_by_run(tenant_id=tenant_id, run_id=run_id)
         if not node_runs:
@@ -597,15 +964,16 @@ class RuntimeApplicationService:
     ) -> tuple[WorkflowRunStatus, str | None, str | None]:
         statuses = {node_run.status for node_run in node_runs}
 
+        active_statuses: set[NodeRunStatus] = {"pending", "queued", "running"}
+        if statuses.intersection(active_statuses):
+            return "running", None, None
+
         if "failed" in statuses:
             failed_node = next((item for item in node_runs if item.status == "failed"), None)
             error_code = failed_node.error_code if failed_node is not None else None
             error_message = failed_node.error_message if failed_node is not None else None
             return "failed", error_code, error_message
 
-        if "running" in statuses:
-            return "running", None, None
-
         terminal_statuses: set[NodeRunStatus] = {"completed", "skipped"}
         if statuses and statuses.issubset(terminal_statuses):
             return "completed", None, None

+ 2 - 0
services/runtime-service/app/db/models/node_run.py

@@ -19,6 +19,8 @@ class NodeRun(TenantMixin, AuditMixin, VersionMixin, Base):
     status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
     worker_key: Mapped[str | None] = mapped_column(String(128), nullable=True)
     lease_expire_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    scheduled_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, index=True)
+    timeout_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, index=True)
     queued_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     started_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
     finished_time: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

+ 52 - 1
services/runtime-service/app/domain/repositories.py

@@ -1,6 +1,6 @@
 from datetime import datetime
 
-from sqlalchemy import select
+from sqlalchemy import or_, select
 from sqlalchemy.orm import Session
 
 from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
@@ -107,15 +107,21 @@ class NodeRunRepository:
         node_id: str,
         node_type: str,
         status: str,
+        scheduled_time: datetime | None = None,
+        timeout_time: datetime | None = None,
+        parent_node_run_id: str | None = None,
     ) -> NodeRun:
         now = datetime.utcnow()
         entity = NodeRun(
             tenant_id=tenant_id,
             run_id=run_id,
+            parent_node_run_id=parent_node_run_id,
             node_id=node_id,
             node_type=node_type,
             status=status,
             queued_time=now,
+            scheduled_time=scheduled_time or now,
+            timeout_time=timeout_time,
         )
         self.db.add(entity)
         self.db.commit()
@@ -157,6 +163,12 @@ class NodeRunRepository:
             .where(NodeRun.tenant_id == tenant_id)
             .where(NodeRun.run_id == run_id)
             .where(NodeRun.status == "queued")
+            .where(
+                or_(
+                    NodeRun.scheduled_time.is_(None),
+                    NodeRun.scheduled_time <= datetime.utcnow(),
+                )
+            )
             .order_by(NodeRun.created_time.asc())
             .limit(1)
         )
@@ -172,6 +184,12 @@ class NodeRunRepository:
             select(NodeRun)
             .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
             .where(NodeRun.status == "queued")
+            .where(
+                or_(
+                    NodeRun.scheduled_time.is_(None),
+                    NodeRun.scheduled_time <= datetime.utcnow(),
+                )
+            )
             .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
             .with_for_update(skip_locked=True)
             .limit(1)
@@ -203,6 +221,7 @@ class NodeRunRepository:
             entity.status = "queued"
             entity.worker_key = None
             entity.lease_expire_time = None
+            entity.scheduled_time = now_time
             entity.queued_time = now_time
             entity.started_time = None
             entity.finished_time = None
@@ -246,6 +265,38 @@ class NodeRunRepository:
         self.db.refresh(entity)
         return entity
 
+    def requeue_for_retry(
+        self,
+        *,
+        node_run_id: str,
+        scheduled_time: datetime,
+        timeout_time: datetime | None,
+        error_code: str | None,
+        error_message: str | None,
+        output_text: str | None,
+        output_json: dict[str, JSONValue] | None,
+    ) -> NodeRun | None:
+        """Put a node run back in the queue for another attempt.
+
+        Increments ``attempt_no``, clears the worker lease and start/finish
+        times, stores the new schedule and timeout, and records the given
+        error and output on the row before committing.
+
+        Returns:
+            The refreshed node run, or ``None`` when ``node_run_id`` is unknown.
+        """
+        node_run = self.db.get(NodeRun, node_run_id)
+        if node_run is None:
+            return None
+
+        # Back to the queue as a new attempt, with no worker holding it.
+        node_run.status = "queued"
+        node_run.attempt_no += 1
+        node_run.worker_key = None
+        node_run.lease_expire_time = None
+
+        # New schedule; start/finish times of the previous attempt are cleared.
+        node_run.scheduled_time = scheduled_time
+        node_run.timeout_time = timeout_time
+        node_run.queued_time = datetime.utcnow()
+        node_run.started_time = None
+        node_run.finished_time = None
+
+        # Record the error and output passed in by the caller.
+        node_run.error_code = error_code
+        node_run.error_message = error_message
+        node_run.output_text = output_text
+        node_run.output_json = output_json
+
+        self.db.commit()
+        self.db.refresh(node_run)
+        return node_run
+
 
 class ExecutionLogRepository:
     def __init__(self, db: Session) -> None: