test_runtime_debugger.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from __future__ import annotations
  2. from collections.abc import Generator
  3. from datetime import datetime
  4. from pathlib import Path
  5. from sqlalchemy.orm import Session
  6. from tests.conftest import (
  7. build_fastapi_test_client,
  8. build_sqlite_database_url,
  9. prepare_known_service_import,
  10. )
  11. def test_runtime_debugger_pause_step_and_breakpoint_continue(tmp_path: Path) -> None:
  12. prepare_known_service_import("runtime-service")
  13. from app.api.routes import get_runtime_application_service
  14. from app.application.services import RuntimeApplicationService
  15. from app.bootstrap.app import create_app
  16. from app.bootstrap.settings import RuntimeServiceSettings
  17. from app.db.session import build_session_factory
  18. from app.domain.repositories import (
  19. ExecutionLogRepository,
  20. NodeArtifactRepository,
  21. NodeRunRepository,
  22. TraceSpanRepository,
  23. WorkflowRunRepository,
  24. )
  25. from app.infrastructure.executors import build_node_execution_dispatcher
  26. from core_db import Base
  27. from core_domain import WorkflowVersionContract
  28. class FakeWorkflowClient:
  29. def get_workflow_version(
  30. self,
  31. *,
  32. tenant_id: str,
  33. workflow_version_id: str,
  34. ) -> WorkflowVersionContract:
  35. assert tenant_id == "t1"
  36. assert workflow_version_id == "wv1"
  37. return WorkflowVersionContract(
  38. id="wv1",
  39. tenant_id="t1",
  40. workflow_id="wf1",
  41. version_no=1,
  42. status="published",
  43. created_time=datetime.utcnow(),
  44. dsl_json={
  45. "code": "debug_flow",
  46. "nodes": [
  47. {"id": "start", "type": "template", "config": {"template": "hello"}},
  48. {"id": "answer", "type": "answer", "config": {"text": "done"}},
  49. ],
  50. "edges": [{"source": "start", "target": "answer"}],
  51. },
  52. )
  53. settings = RuntimeServiceSettings(database_url=build_sqlite_database_url(tmp_path, "runtime.db"))
  54. session_factory = build_session_factory(settings)
  55. engine = session_factory.kw["bind"]
  56. Base.metadata.create_all(bind=engine)
  57. def override_service() -> Generator[RuntimeApplicationService, None, None]:
  58. db: Session = session_factory()
  59. try:
  60. yield RuntimeApplicationService(
  61. workflow_run_repository=WorkflowRunRepository(db),
  62. node_run_repository=NodeRunRepository(db),
  63. execution_log_repository=ExecutionLogRepository(db),
  64. node_artifact_repository=NodeArtifactRepository(db),
  65. trace_span_repository=TraceSpanRepository(db),
  66. execution_dispatcher=build_node_execution_dispatcher(),
  67. workflow_client=FakeWorkflowClient(),
  68. )
  69. finally:
  70. db.close()
  71. app = create_app()
  72. app.state.session_factory = session_factory
  73. app.dependency_overrides[get_runtime_application_service] = override_service
  74. client = build_fastapi_test_client(app)
  75. create_response = client.post(
  76. "/runtime/runs",
  77. json={
  78. "tenant_id": "t1",
  79. "app_id": "app1",
  80. "app_version_id": "av1",
  81. "workflow_id": "wf1",
  82. "workflow_version_id": "wv1",
  83. },
  84. )
  85. assert create_response.status_code == 200
  86. run_id = create_response.json()["run"]["id"]
  87. pause_response = client.post(
  88. f"/runtime/runs/{run_id}/debug/pause",
  89. params={"tenant_id": "t1"},
  90. )
  91. assert pause_response.status_code == 200
  92. assert pause_response.json()["run"]["status"] == "paused"
  93. assert pause_response.json()["queued_node_ids"] == ["start"]
  94. protected_execute_response = client.post(
  95. f"/runtime/runs/{run_id}/execute-next",
  96. params={"tenant_id": "t1"},
  97. json={"worker_key": "debugger"},
  98. )
  99. assert protected_execute_response.status_code == 200
  100. assert protected_execute_response.json()["executor_name"] == "debug_paused"
  101. assert protected_execute_response.json()["node_run"]["status"] == "queued"
  102. step_response = client.post(
  103. f"/runtime/runs/{run_id}/debug/step",
  104. params={"tenant_id": "t1"},
  105. json={"worker_key": "debugger"},
  106. )
  107. assert step_response.status_code == 200
  108. step_payload = step_response.json()
  109. assert step_payload["reason"] == "step_completed"
  110. assert [item["node_id"] for item in step_payload["executed_node_runs"]] == ["start"]
  111. assert step_payload["snapshot"]["run"]["status"] == "paused"
  112. assert step_payload["snapshot"]["completed_node_ids"] == ["start"]
  113. assert step_payload["snapshot"]["queued_node_ids"] == ["answer"]
  114. breakpoint_response = client.post(
  115. f"/runtime/runs/{run_id}/debug/continue",
  116. params={"tenant_id": "t1"},
  117. json={"worker_key": "debugger", "breakpoint_node_ids": ["answer"], "max_steps": 5},
  118. )
  119. assert breakpoint_response.status_code == 200
  120. breakpoint_payload = breakpoint_response.json()
  121. assert breakpoint_payload["reason"] == "breakpoint_hit"
  122. assert breakpoint_payload["paused_before_node_id"] == "answer"
  123. assert breakpoint_payload["executed_node_runs"] == []
  124. assert breakpoint_payload["snapshot"]["queued_node_ids"] == ["answer"]
  125. finish_response = client.post(
  126. f"/runtime/runs/{run_id}/debug/continue",
  127. params={"tenant_id": "t1"},
  128. json={"worker_key": "debugger", "max_steps": 5},
  129. )
  130. assert finish_response.status_code == 200
  131. finish_payload = finish_response.json()
  132. assert finish_payload["snapshot"]["run"]["status"] == "completed"
  133. assert finish_payload["snapshot"]["completed_node_ids"] == ["start", "answer"]
  134. assert finish_payload["snapshot"]["queued_node_ids"] == []
  135. engine.dispose()