Quellcode durchsuchen

feat: wire redis task queues into workers

Jax Docker vor 1 Monat
Ursprung
Commit
495582b5da

+ 70 - 0
libs/core-shared/src/core_shared/task_queue.py

@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Protocol
+
+from core_shared.types import JSONValue
+
+if TYPE_CHECKING:
+    from redis import Redis
+
+
+AGENT_RUN_QUEUE = "agent-platform:agent-runs"
+RUNTIME_NODE_RUN_QUEUE = "agent-platform:runtime-node-runs"
+SCHEDULED_JOB_QUEUE = "agent-platform:scheduled-jobs"
+TEAM_RUN_QUEUE = "agent-platform:team-runs"
+
+
+class TaskQueueConsumer(Protocol):
+    def dequeue(self, *, timeout_seconds: int = 1) -> dict[str, JSONValue] | None: ...
+
+
+class TaskQueuePublisher:
+    def __init__(self, *, client: "Redis") -> None:
+        self._client = client
+
+    def publish_agent_run(self, *, tenant_id: str, agent_run_id: str) -> bool:
+        return self._publish(
+            queue_name=AGENT_RUN_QUEUE,
+            payload={"tenant_id": tenant_id, "agent_run_id": agent_run_id},
+        )
+
+    def publish_runtime_node_run(self, *, tenant_id: str, node_run_id: str) -> bool:
+        return self._publish(
+            queue_name=RUNTIME_NODE_RUN_QUEUE,
+            payload={"tenant_id": tenant_id, "node_run_id": node_run_id},
+        )
+
+    def publish_scheduled_job(self, *, tenant_id: str, scheduled_job_id: str) -> bool:
+        return self._publish(
+            queue_name=SCHEDULED_JOB_QUEUE,
+            payload={"tenant_id": tenant_id, "scheduled_job_id": scheduled_job_id},
+        )
+
+    def publish_team_run(self, *, tenant_id: str, team_run_id: str) -> bool:
+        return self._publish(
+            queue_name=TEAM_RUN_QUEUE,
+            payload={"tenant_id": tenant_id, "team_run_id": team_run_id},
+        )
+
+    def _publish(self, *, queue_name: str, payload: dict[str, JSONValue]) -> bool:
+        try:
+            from core_shared.redis_primitives import RedisQueue
+
+            RedisQueue(client=self._client, name=queue_name).enqueue(payload)
+        except Exception:
+            return False
+        return True
+
+
+def build_task_queue_consumer(
+    *,
+    client: "Redis | None",
+    queue_name: str,
+) -> TaskQueueConsumer | None:
+    if client is None:
+        return None
+    try:
+        from core_shared.redis_primitives import RedisQueue
+    except Exception:
+        return None
+    return RedisQueue(client=client, name=queue_name)

+ 13 - 1
services/agent-service/app/application/services.py

@@ -15,7 +15,8 @@ from core_domain import (
     MemorySearchRequestContract,
     MemorySearchResultContract,
 )
-from core_shared import JSONValue
+from core_shared import JSONValue, try_build_redis_client
+from core_shared.task_queue import TaskQueuePublisher
 
 from app.bootstrap.settings import AgentServiceSettings
 from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentVersion
@@ -52,6 +53,7 @@ class AgentApplicationService:
         tool_client: ToolServiceClient | None = None,
         skill_client: SkillServiceClient | None = None,
         event_client: EventServiceClient | None = None,
+        task_queue_publisher: TaskQueuePublisher | None = None,
         react_max_steps: int = 5,
         react_max_tool_calls: int = 10,
         react_tool_retry_count: int = 1,
@@ -65,6 +67,7 @@ class AgentApplicationService:
         self.tool_client = tool_client
         self.skill_client = skill_client
         self.event_client = event_client
+        self.task_queue_publisher = task_queue_publisher
         self.react_max_steps = react_max_steps
         self.react_max_tool_calls = react_max_tool_calls
         self.react_tool_retry_count = react_tool_retry_count
@@ -141,6 +144,11 @@ class AgentApplicationService:
             agent_run=agent_run,
             payload_json={"agent_run_id": agent_run.id, "status": agent_run.status},
         )
+        if self.task_queue_publisher is not None:
+            self.task_queue_publisher.publish_agent_run(
+                tenant_id=agent_run.tenant_id,
+                agent_run_id=agent_run.id,
+            )
         return agent_run
 
     def list_agent_runs(
@@ -1341,6 +1349,7 @@ def build_agent_application_service(
     db: Session,
     settings: AgentServiceSettings,
 ) -> AgentApplicationService:
+    redis_client = try_build_redis_client(settings.redis_url)
     return AgentApplicationService(
         agent_repository=AgentDefinitionRepository(db),
         agent_version_repository=AgentVersionRepository(db),
@@ -1366,5 +1375,8 @@ def build_agent_application_service(
             base_url=settings.event_service_url,
             timeout_seconds=settings.event_service_timeout_seconds,
         ),
+        task_queue_publisher=(
+            TaskQueuePublisher(client=redis_client) if redis_client is not None else None
+        ),
         react_max_steps=settings.react_max_steps,
     )

+ 19 - 1
services/agent-service/app/worker.py

@@ -5,11 +5,13 @@ import socket
 import time
 import traceback
 from dataclasses import dataclass
+from math import ceil
 from uuid import uuid4
 
 from sqlalchemy.orm import Session, sessionmaker
 
 from core_shared import try_build_redis_client
+from core_shared.task_queue import AGENT_RUN_QUEUE, build_task_queue_consumer
 
 from app.application.services import build_agent_application_service
 from app.bootstrap.settings import AgentServiceSettings
@@ -36,6 +38,10 @@ class AgentWorker:
         self.session_factory = session_factory
         self.worker_key = worker_key
         self.redis_client = try_build_redis_client(settings.redis_url)
+        self.task_queue = build_task_queue_consumer(
+            client=self.redis_client,
+            queue_name=AGENT_RUN_QUEUE,
+        )
 
     def run_forever(self) -> AgentWorkerStats:
         executed_count = 0
@@ -55,7 +61,8 @@ class AgentWorker:
                 idle_count = 0
             else:
                 idle_count += 1
-                time.sleep(self.settings.worker_poll_interval_seconds)
+                if self.task_queue is None:
+                    time.sleep(self.settings.worker_poll_interval_seconds)
 
             if self.settings.worker_max_idle_cycles is not None:
                 if idle_count >= self.settings.worker_max_idle_cycles:
@@ -67,6 +74,7 @@ class AgentWorker:
                     )
 
     def run_once(self) -> bool:
+        self._wait_for_queue_signal()
         db = self.session_factory()
         try:
             service = build_agent_application_service(db=db, settings=self.settings)
@@ -80,6 +88,16 @@ class AgentWorker:
         finally:
             db.close()
 
+    def _wait_for_queue_signal(self) -> None:
+        if self.task_queue is None:
+            return
+        try:
+            self.task_queue.dequeue(
+                timeout_seconds=max(1, ceil(self.settings.worker_poll_interval_seconds)),
+            )
+        except Exception:
+            return
+
 
 def build_worker_key() -> str:
     configured_key = os.getenv("AGENT_PLATFORM_WORKER_KEY")

+ 21 - 2
services/runtime-service/app/application/services.py

@@ -45,7 +45,8 @@ from app.schemas.run import (
     RunExecuteRequest,
     WorkflowRunStatusUpdateRequest,
 )
-from core_shared import JSONValue
+from core_shared import JSONValue, try_build_redis_client
+from core_shared.task_queue import TaskQueuePublisher
 
 
 class RuntimeApplicationService:
@@ -59,6 +60,7 @@ class RuntimeApplicationService:
         execution_dispatcher: NodeExecutionDispatcher,
         workflow_client: WorkflowServiceClient | None = None,
         event_client: EventServiceClient | None = None,
+        task_queue_publisher: TaskQueuePublisher | None = None,
     ) -> None:
         self.workflow_run_repository = workflow_run_repository
         self.node_run_repository = node_run_repository
@@ -68,6 +70,7 @@ class RuntimeApplicationService:
         self.execution_dispatcher = execution_dispatcher
         self.workflow_client = workflow_client
         self.event_client = event_client
+        self.task_queue_publisher = task_queue_publisher
 
     def create_run(self, payload: RunCreateRequest) -> tuple[WorkflowRun, NodeRun | None]:
         initial_node = payload.initial_node or self._plan_initial_node(payload)
@@ -118,6 +121,7 @@ class RuntimeApplicationService:
                     "status": initial_node.status,
                 },
             )
+            self._publish_node_run_to_queue(node_run)
 
         self._log_event(
             tenant_id=payload.tenant_id,
@@ -402,6 +406,7 @@ class RuntimeApplicationService:
             )
             if retried_node_run is None:
                 return None
+            self._publish_node_run_to_queue(retried_node_run)
             self.trace_span_repository.finish(
                 span_id=trace_span.id,
                 status="error",
@@ -732,7 +737,7 @@ class RuntimeApplicationService:
             ):
                 continue
             scheduled_time, timeout_time = self._build_node_timing(successor_config)
-            self.node_run_repository.create(
+            created = self.node_run_repository.create(
                 tenant_id=node_run.tenant_id,
                 run_id=node_run.run_id,
                 parent_node_run_id=node_run.id,
@@ -742,6 +747,7 @@ class RuntimeApplicationService:
                 scheduled_time=scheduled_time,
                 timeout_time=timeout_time,
             )
+            self._publish_node_run_to_queue(created)
             existing_node_counts[successor.node_id] = (
                 existing_node_counts.get(successor.node_id, 0) + 1
             )
@@ -958,6 +964,7 @@ class RuntimeApplicationService:
             scheduled_time=scheduled_time,
             timeout_time=timeout_time,
         )
+        self._publish_node_run_to_queue(created)
         self._log_event(
             tenant_id=node_run.tenant_id,
             run_id=node_run.run_id,
@@ -1143,12 +1150,21 @@ class RuntimeApplicationService:
         except EventServiceClientError:
             return
 
+    def _publish_node_run_to_queue(self, node_run: NodeRun) -> None:
+        if node_run.status != "queued" or self.task_queue_publisher is None:
+            return
+        self.task_queue_publisher.publish_runtime_node_run(
+            tenant_id=node_run.tenant_id,
+            node_run_id=node_run.id,
+        )
+
 
 def build_runtime_application_service(
     *,
     db: Session,
     settings: RuntimeServiceSettings,
 ) -> RuntimeApplicationService:
+    redis_client = try_build_redis_client(settings.redis_url)
     return RuntimeApplicationService(
         workflow_run_repository=WorkflowRunRepository(db),
         node_run_repository=NodeRunRepository(db),
@@ -1173,4 +1189,7 @@ def build_runtime_application_service(
             base_url=settings.event_service_url,
             timeout_seconds=settings.event_service_timeout_seconds,
         ),
+        task_queue_publisher=(
+            TaskQueuePublisher(client=redis_client) if redis_client is not None else None
+        ),
     )

+ 20 - 2
services/runtime-service/app/worker.py

@@ -5,11 +5,13 @@ import socket
 import time
 import traceback
 from dataclasses import dataclass
+from math import ceil
 from uuid import uuid4
 
-from sqlalchemy.orm import sessionmaker, Session
+from sqlalchemy.orm import Session, sessionmaker
 
 from core_shared import try_build_redis_client
+from core_shared.task_queue import RUNTIME_NODE_RUN_QUEUE, build_task_queue_consumer
 
 from app.application.services import build_runtime_application_service
 from app.bootstrap.settings import RuntimeServiceSettings
@@ -36,6 +38,10 @@ class RuntimeWorker:
         self.session_factory = session_factory
         self.worker_key = worker_key
         self.redis_client = try_build_redis_client(settings.redis_url)
+        self.task_queue = build_task_queue_consumer(
+            client=self.redis_client,
+            queue_name=RUNTIME_NODE_RUN_QUEUE,
+        )
 
     def run_forever(self) -> RuntimeWorkerStats:
         executed_count = 0
@@ -55,7 +61,8 @@ class RuntimeWorker:
                 idle_count = 0
             else:
                 idle_count += 1
-                time.sleep(self.settings.worker_poll_interval_seconds)
+                if self.task_queue is None:
+                    time.sleep(self.settings.worker_poll_interval_seconds)
 
             if self.settings.worker_max_idle_cycles is not None:
                 if idle_count >= self.settings.worker_max_idle_cycles:
@@ -67,6 +74,7 @@ class RuntimeWorker:
                     )
 
     def run_once(self) -> bool:
+        self._wait_for_queue_signal()
         db = self.session_factory()
         try:
             service = build_runtime_application_service(db=db, settings=self.settings)
@@ -79,6 +87,16 @@ class RuntimeWorker:
         finally:
             db.close()
 
+    def _wait_for_queue_signal(self) -> None:
+        if self.task_queue is None:
+            return
+        try:
+            self.task_queue.dequeue(
+                timeout_seconds=max(1, ceil(self.settings.worker_poll_interval_seconds)),
+            )
+        except Exception:
+            return
+
 
 def build_worker_key() -> str:
     configured_key = os.getenv("AGENT_PLATFORM_WORKER_KEY")

+ 13 - 2
services/scheduler-service/app/api/routes.py

@@ -1,10 +1,13 @@
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from sqlalchemy import text
 from sqlalchemy.orm import Session
 
 from core_domain import ScheduledJobStatus, ScheduledJobType, ServiceHealth
+from core_shared import try_build_redis_client
+from core_shared.task_queue import TaskQueuePublisher
 
 from app.application.services import SchedulerApplicationService
+from app.bootstrap.settings import SchedulerServiceSettings
 from app.db.session import get_db
 from app.domain.repositories import ScheduledJobRepository
 from app.schemas.scheduler import (
@@ -18,9 +21,17 @@ router = APIRouter()
 
 
 def get_scheduler_application_service(
+    request: Request,
     db: Session = Depends(get_db),
 ) -> SchedulerApplicationService:
-    return SchedulerApplicationService(job_repository=ScheduledJobRepository(db))
+    settings: SchedulerServiceSettings = request.app.state.settings
+    redis_client = try_build_redis_client(settings.redis_url)
+    return SchedulerApplicationService(
+        job_repository=ScheduledJobRepository(db),
+        task_queue_publisher=(
+            TaskQueuePublisher(client=redis_client) if redis_client is not None else None
+        ),
+    )
 
 
 @router.get("/health", response_model=ServiceHealth)

+ 15 - 2
services/scheduler-service/app/application/services.py

@@ -1,6 +1,7 @@
 from datetime import datetime
 
 from core_domain import ScheduledJobStatus, ScheduledJobType
+from core_shared.task_queue import TaskQueuePublisher
 
 from app.db.models import ScheduledJob
 from app.domain.repositories import ScheduledJobRepository
@@ -12,11 +13,17 @@ from app.schemas.scheduler import (
 
 
 class SchedulerApplicationService:
-    def __init__(self, *, job_repository: ScheduledJobRepository) -> None:
+    def __init__(
+        self,
+        *,
+        job_repository: ScheduledJobRepository,
+        task_queue_publisher: TaskQueuePublisher | None = None,
+    ) -> None:
         self.job_repository = job_repository
+        self.task_queue_publisher = task_queue_publisher
 
     def create_job(self, payload: ScheduledJobCreateRequest) -> ScheduledJob:
-        return self.job_repository.create(
+        job = self.job_repository.create(
             tenant_id=payload.tenant_id,
             job_type=payload.job_type,
             name=payload.name,
@@ -29,6 +36,12 @@ class SchedulerApplicationService:
             max_attempts=payload.max_attempts,
             metadata_json=payload.metadata_json,
         )
+        if self.task_queue_publisher is not None:
+            self.task_queue_publisher.publish_scheduled_job(
+                tenant_id=job.tenant_id,
+                scheduled_job_id=job.id,
+            )
+        return job
 
     def list_jobs(
         self,

+ 19 - 1
services/scheduler-service/app/worker.py

@@ -6,12 +6,14 @@ import time
 import traceback
 from dataclasses import dataclass
 from datetime import datetime
+from math import ceil
 from uuid import uuid4
 
 import httpx
 from sqlalchemy.orm import Session, sessionmaker
 
 from core_shared import JSONValue, try_build_redis_client
+from core_shared.task_queue import SCHEDULED_JOB_QUEUE, build_task_queue_consumer
 
 from app.bootstrap.settings import SchedulerServiceSettings
 from app.db.models import ScheduledJob
@@ -87,6 +89,10 @@ class SchedulerWorker:
         self.worker_key = worker_key
         self.executor = ScheduledJobExecutor(settings=settings)
         self.redis_client = try_build_redis_client(settings.redis_url)
+        self.task_queue = build_task_queue_consumer(
+            client=self.redis_client,
+            queue_name=SCHEDULED_JOB_QUEUE,
+        )
 
     def run_forever(self) -> SchedulerWorkerStats:
         executed_count = 0
@@ -106,7 +112,8 @@ class SchedulerWorker:
                 idle_count = 0
             else:
                 idle_count += 1
-                time.sleep(self.settings.worker_poll_interval_seconds)
+                if self.task_queue is None:
+                    time.sleep(self.settings.worker_poll_interval_seconds)
 
             if self.settings.worker_max_idle_cycles is not None:
                 if idle_count >= self.settings.worker_max_idle_cycles:
@@ -118,6 +125,7 @@ class SchedulerWorker:
                     )
 
     def run_once(self) -> bool:
+        self._wait_for_queue_signal()
         db = self.session_factory()
         try:
             repository = ScheduledJobRepository(db)
@@ -134,6 +142,16 @@ class SchedulerWorker:
         finally:
             db.close()
 
+    def _wait_for_queue_signal(self) -> None:
+        if self.task_queue is None:
+            return
+        try:
+            self.task_queue.dequeue(
+                timeout_seconds=max(1, ceil(self.settings.worker_poll_interval_seconds)),
+            )
+        except Exception:
+            return
+
     def _execute_claimed_job(
         self,
         *,

+ 13 - 1
services/team-service/app/application/services.py

@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
 
 from core_events import EventPublishContract, EventServiceClient, EventServiceClientError
 from core_domain import AgentRunContract, TeamMemberContract
-from core_shared import JSONValue
+from core_shared import JSONValue, try_build_redis_client
+from core_shared.task_queue import TaskQueuePublisher
 
 from app.bootstrap.settings import TeamServiceSettings
 from app.db.models import TeamDefinition, TeamRun, TeamVersion
@@ -31,12 +32,14 @@ class TeamApplicationService:
         team_run_repository: TeamRunRepository,
         agent_client: AgentServiceClient | None = None,
         event_client: EventServiceClient | None = None,
+        task_queue_publisher: TaskQueuePublisher | None = None,
     ) -> None:
         self.team_repository = team_repository
         self.team_version_repository = team_version_repository
         self.team_run_repository = team_run_repository
         self.agent_client = agent_client
         self.event_client = event_client
+        self.task_queue_publisher = task_queue_publisher
 
     def create_team(self, payload: TeamCreateRequest) -> TeamDefinition:
         return self.team_repository.create(
@@ -106,6 +109,11 @@ class TeamApplicationService:
             team_run=team_run,
             payload_json={"team_run_id": team_run.id, "status": team_run.status},
         )
+        if self.task_queue_publisher is not None:
+            self.task_queue_publisher.publish_team_run(
+                tenant_id=team_run.tenant_id,
+                team_run_id=team_run.id,
+            )
         return team_run
 
     def list_team_runs(
@@ -485,6 +493,7 @@ def build_team_application_service(
     team_run_repository: TeamRunRepository,
     settings: TeamServiceSettings,
 ) -> TeamApplicationService:
+    redis_client = try_build_redis_client(settings.redis_url)
     return TeamApplicationService(
         team_repository=team_repository,
         team_version_repository=team_version_repository,
@@ -497,4 +506,7 @@ def build_team_application_service(
             base_url=settings.event_service_url,
             timeout_seconds=settings.event_service_timeout_seconds,
         ),
+        task_queue_publisher=(
+            TaskQueuePublisher(client=redis_client) if redis_client is not None else None
+        ),
     )

+ 22 - 1
services/team-service/app/worker.py

@@ -5,10 +5,14 @@ import socket
 import time
 import traceback
 from dataclasses import dataclass
+from math import ceil
 from uuid import uuid4
 
 from sqlalchemy.orm import Session, sessionmaker
 
+from core_shared import try_build_redis_client
+from core_shared.task_queue import TEAM_RUN_QUEUE, build_task_queue_consumer
+
 from app.application.services import build_team_application_service
 from app.bootstrap.settings import TeamServiceSettings
 from app.db.session import build_session_factory
@@ -38,6 +42,11 @@ class TeamWorker:
         self.settings = settings
         self.session_factory = session_factory
         self.worker_key = worker_key
+        self.redis_client = try_build_redis_client(settings.redis_url)
+        self.task_queue = build_task_queue_consumer(
+            client=self.redis_client,
+            queue_name=TEAM_RUN_QUEUE,
+        )
 
     def run_forever(self) -> TeamWorkerStats:
         executed_count = 0
@@ -57,7 +66,8 @@ class TeamWorker:
                 idle_count = 0
             else:
                 idle_count += 1
-                time.sleep(self.settings.worker_poll_interval_seconds)
+                if self.task_queue is None:
+                    time.sleep(self.settings.worker_poll_interval_seconds)
 
             if self.settings.worker_max_idle_cycles is not None:
                 if idle_count >= self.settings.worker_max_idle_cycles:
@@ -69,6 +79,7 @@ class TeamWorker:
                     )
 
     def run_once(self) -> bool:
+        self._wait_for_queue_signal()
         db = self.session_factory()
         try:
             service = build_team_application_service(
@@ -86,6 +97,16 @@ class TeamWorker:
         finally:
             db.close()
 
+    def _wait_for_queue_signal(self) -> None:
+        if self.task_queue is None:
+            return
+        try:
+            self.task_queue.dequeue(
+                timeout_seconds=max(1, ceil(self.settings.worker_poll_interval_seconds)),
+            )
+        except Exception:
+            return
+
 
 def build_worker_key() -> str:
     configured_key = os.getenv("AGENT_PLATFORM_WORKER_KEY")