test_runtime_debugger.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from __future__ import annotations
  2. from collections.abc import Generator
  3. from datetime import datetime
  4. from pathlib import Path
  5. from sqlalchemy.orm import Session
  6. from tests.conftest import (
  7. build_fastapi_test_client,
  8. build_postgres_database_url,
  9. prepare_known_service_import,
  10. )
  def test_runtime_debugger_pause_step_and_breakpoint_continue(tmp_path: Path) -> None:
      """Drive a two-node run (start -> answer) through the runtime debug endpoints.

      Scenario, as asserted below:

      1. Create a run, then pause it; ``start`` is the only queued node.
      2. ``execute-next`` while paused executes nothing: it reports the
         ``debug_paused`` executor and the node run stays ``queued``.
      3. ``debug/step`` executes exactly ``start`` and leaves the run paused
         with ``answer`` queued.
      4. ``debug/continue`` with a breakpoint on ``answer`` stops before it,
         executing no node runs.
      5. ``debug/continue`` without breakpoints runs the workflow to completion.
      """
      from datetime import timezone

      # Must run before the service-local ``app.*`` imports below.
      prepare_known_service_import("runtime-service")
      from app.api.routes import get_runtime_application_service
      from app.application.services import RuntimeApplicationService
      from app.bootstrap.app import create_app
      from app.bootstrap.settings import RuntimeServiceSettings
      from app.db.session import build_session_factory
      from app.domain.repositories import (
          ExecutionLogRepository,
          NodeArtifactRepository,
          NodeRunRepository,
          TraceSpanRepository,
          WorkflowRunRepository,
      )
      from app.infrastructure.executors import build_node_execution_dispatcher
      from core_db import Base
      from core_domain import WorkflowConfigContract

      class FakeWorkflowClient:
          """Stub workflow client serving one fixed config, ``wc1``.

          The DSL has two nodes wired ``start -> answer``.
          """

          def get_workflow_config(
              self,
              *,
              workflow_config_id: str,
          ) -> WorkflowConfigContract:
              assert workflow_config_id == "wc1"
              return WorkflowConfigContract(
                  id="wc1",
                  workflow_id="wf1",
                  # Naive UTC timestamp: the documented replacement for the
                  # deprecated datetime.utcnow() (deprecated in Python 3.12).
                  created_time=datetime.now(timezone.utc).replace(tzinfo=None),
                  dsl_json={
                      "code": "debug_flow",
                      "nodes": [
                          {"id": "start", "type": "template", "config": {"template": "hello"}},
                          {"id": "answer", "type": "answer", "config": {"text": "done"}},
                      ],
                      "edges": [{"source": "start", "target": "answer"}],
                  },
              )

      settings = RuntimeServiceSettings(
          database_url=build_postgres_database_url(tmp_path, "runtime-debugger"),
      )
      session_factory = build_session_factory(settings)
      # The engine is taken from the session factory's bound kwargs so it can
      # be used for schema creation and disposed when the test ends.
      engine = session_factory.kw["bind"]
      try:
          Base.metadata.create_all(bind=engine)

          def override_service() -> Generator[RuntimeApplicationService, None, None]:
              """Yield a service wired to a fresh session and the fake workflow client.

              The session is closed once the request that used it is done.
              """
              db: Session = session_factory()
              try:
                  yield RuntimeApplicationService(
                      workflow_run_repository=WorkflowRunRepository(db),
                      node_run_repository=NodeRunRepository(db),
                      execution_log_repository=ExecutionLogRepository(db),
                      node_artifact_repository=NodeArtifactRepository(db),
                      trace_span_repository=TraceSpanRepository(db),
                      execution_dispatcher=build_node_execution_dispatcher(),
                      workflow_client=FakeWorkflowClient(),
                  )
              finally:
                  db.close()

          app = create_app()
          # NOTE(review): presumably the app reads its session factory from
          # app.state; pointing it at the test database — confirm.
          app.state.session_factory = session_factory
          app.dependency_overrides[get_runtime_application_service] = override_service
          client = build_fastapi_test_client(app)

          def post_ok(path: str, payload: dict | None = None) -> dict:
              """POST ``payload`` to ``path``, assert HTTP 200, and return the JSON body.

              The response text is used as the assertion message so a failing
              status shows the server's error payload.
              """
              response = client.post(path, json=payload)
              assert response.status_code == 200, response.text
              return response.json()

          create_payload = post_ok(
              "/runtime/runs",
              {
                  "app_id": "app1",
                  "app_config_id": "ac1",
                  "workflow_id": "wf1",
                  "workflow_config_id": "wc1",
              },
          )
          run_url = f"/runtime/runs/{create_payload['run']['id']}"

          # Pausing leaves only the entry node queued.
          pause_payload = post_ok(f"{run_url}/debug/pause")
          assert pause_payload["run"]["status"] == "paused"
          assert pause_payload["queued_node_ids"] == ["start"]

          # While paused, execute-next must not run the queued node.
          protected_payload = post_ok(f"{run_url}/execute-next", {"worker_key": "debugger"})
          assert protected_payload["executor_name"] == "debug_paused"
          assert protected_payload["node_run"]["status"] == "queued"

          # A single step runs exactly one node and stays paused.
          step_payload = post_ok(f"{run_url}/debug/step", {"worker_key": "debugger"})
          assert step_payload["reason"] == "step_completed"
          assert [item["node_id"] for item in step_payload["executed_node_runs"]] == ["start"]
          assert step_payload["snapshot"]["run"]["status"] == "paused"
          assert step_payload["snapshot"]["completed_node_ids"] == ["start"]
          assert step_payload["snapshot"]["queued_node_ids"] == ["answer"]

          # Continuing with a breakpoint on the next queued node stops before it.
          breakpoint_payload = post_ok(
              f"{run_url}/debug/continue",
              {"worker_key": "debugger", "breakpoint_node_ids": ["answer"], "max_steps": 5},
          )
          assert breakpoint_payload["reason"] == "breakpoint_hit"
          assert breakpoint_payload["paused_before_node_id"] == "answer"
          assert breakpoint_payload["executed_node_runs"] == []
          assert breakpoint_payload["snapshot"]["queued_node_ids"] == ["answer"]

          # Continuing without breakpoints finishes the run.
          finish_payload = post_ok(
              f"{run_url}/debug/continue",
              {"worker_key": "debugger", "max_steps": 5},
          )
          assert finish_payload["snapshot"]["run"]["status"] == "completed"
          assert finish_payload["snapshot"]["completed_node_ids"] == ["start", "answer"]
          assert finish_payload["snapshot"]["queued_node_ids"] == []
      finally:
          # Always release the engine's connection pool, even when an
          # assertion above fails.
          engine.dispose()