test_runtime_debugger.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from __future__ import annotations
  2. from collections.abc import Generator
  3. from datetime import datetime
  4. from pathlib import Path
  5. from sqlalchemy.orm import Session
  6. from tests.conftest import build_fastapi_test_client, prepare_service_import
  def test_runtime_debugger_pause_step_and_breakpoint_continue(tmp_path: Path) -> None:
      """Drive the runtime debugger through pause, step, breakpoint and continue.

      The workflow served by the fake client has two nodes wired
      ``start -> answer``. The scenario, exercised over the HTTP API, is:

      1. Create a run, then pause it: the run is ``paused`` with ``start`` queued.
      2. ``execute-next`` on the paused run executes nothing: it reports the
         ``debug_paused`` executor and the node run stays ``queued``.
      3. ``debug/step`` executes exactly ``start``; the run stays ``paused``
         with ``answer`` queued.
      4. ``debug/continue`` with a breakpoint on ``answer`` stops *before*
         executing it (``breakpoint_hit``, no node runs executed).
      5. ``debug/continue`` without breakpoints runs the remaining node and the
         run ends ``completed``.

      Args:
          tmp_path: pytest-provided directory that holds the SQLite database.
      """
      # The app.* / core_* imports below come after prepare_service_import();
      # presumably it makes runtime-service and the listed libs importable,
      # so keep these imports function-local and in this order.
      prepare_service_import(
          "runtime-service",
          libs=("core-domain", "core-shared", "core-db", "core-events", "core-dsl"),
      )
      from datetime import timezone

      from app.api.routes import get_runtime_application_service
      from app.application.services import RuntimeApplicationService
      from app.bootstrap.app import create_app
      from app.bootstrap.settings import RuntimeServiceSettings
      from app.db.session import build_session_factory
      from app.domain.repositories import (
          ExecutionLogRepository,
          NodeArtifactRepository,
          NodeRunRepository,
          TraceSpanRepository,
          WorkflowRunRepository,
      )
      from app.infrastructure.executors import build_node_execution_dispatcher
      from core_db import Base
      from core_domain import WorkflowVersionContract

      class FakeWorkflowClient:
          """Stub workflow client that always serves the ``start -> answer`` DSL."""

          def get_workflow_version(
              self,
              *,
              tenant_id: str,
              workflow_version_id: str,
          ) -> WorkflowVersionContract:
              # Fail loudly if the runtime asks for any other tenant/version.
              assert tenant_id == "t1"
              assert workflow_version_id == "wv1"
              return WorkflowVersionContract(
                  id="wv1",
                  tenant_id="t1",
                  workflow_id="wf1",
                  version_no=1,
                  status="published",
                  # Naive UTC timestamp, the same value datetime.utcnow() gave,
                  # without the DeprecationWarning utcnow() emits on 3.12+.
                  created_time=datetime.now(timezone.utc).replace(tzinfo=None),
                  dsl_json={
                      "code": "debug_flow",
                      "nodes": [
                          {"id": "start", "type": "template", "config": {"template": "hello"}},
                          {"id": "answer", "type": "answer", "config": {"text": "done"}},
                      ],
                      "edges": [{"source": "start", "target": "answer"}],
                  },
              )

      settings = RuntimeServiceSettings(database_url=f"sqlite:///{tmp_path / 'runtime.db'}")
      session_factory = build_session_factory(settings)
      # The session factory carries the engine it was bound to; the test uses it
      # to create the schema and to dispose the pool once the test is over.
      engine = session_factory.kw["bind"]
      try:
          Base.metadata.create_all(bind=engine)

          def override_service() -> Generator[RuntimeApplicationService, None, None]:
              """Yield a service on a fresh session; close the session on finalize."""
              db: Session = session_factory()
              try:
                  yield RuntimeApplicationService(
                      workflow_run_repository=WorkflowRunRepository(db),
                      node_run_repository=NodeRunRepository(db),
                      execution_log_repository=ExecutionLogRepository(db),
                      node_artifact_repository=NodeArtifactRepository(db),
                      trace_span_repository=TraceSpanRepository(db),
                      execution_dispatcher=build_node_execution_dispatcher(),
                      workflow_client=FakeWorkflowClient(),
                  )
              finally:
                  db.close()

          app = create_app()
          app.state.session_factory = session_factory
          app.dependency_overrides[get_runtime_application_service] = override_service
          client = build_fastapi_test_client(app)
          tenant_params = {"tenant_id": "t1"}

          create_response = client.post(
              "/runtime/runs",
              json={
                  "tenant_id": "t1",
                  "app_id": "app1",
                  "app_version_id": "av1",
                  "workflow_id": "wf1",
                  "workflow_version_id": "wv1",
              },
          )
          assert create_response.status_code == 200
          run_id = create_response.json()["run"]["id"]
          run_url = f"/runtime/runs/{run_id}"

          # Pausing a fresh run leaves the entry node queued.
          pause_response = client.post(f"{run_url}/debug/pause", params=tenant_params)
          assert pause_response.status_code == 200
          pause_payload = pause_response.json()
          assert pause_payload["run"]["status"] == "paused"
          assert pause_payload["queued_node_ids"] == ["start"]

          # execute-next must not advance a paused run.
          protected_execute_response = client.post(
              f"{run_url}/execute-next",
              params=tenant_params,
              json={"worker_key": "debugger"},
          )
          assert protected_execute_response.status_code == 200
          protected_payload = protected_execute_response.json()
          assert protected_payload["executor_name"] == "debug_paused"
          assert protected_payload["node_run"]["status"] == "queued"

          # A single step runs exactly one node and keeps the run paused.
          step_response = client.post(
              f"{run_url}/debug/step",
              params=tenant_params,
              json={"worker_key": "debugger"},
          )
          assert step_response.status_code == 200
          step_payload = step_response.json()
          assert step_payload["reason"] == "step_completed"
          assert [item["node_id"] for item in step_payload["executed_node_runs"]] == ["start"]
          assert step_payload["snapshot"]["run"]["status"] == "paused"
          assert step_payload["snapshot"]["completed_node_ids"] == ["start"]
          assert step_payload["snapshot"]["queued_node_ids"] == ["answer"]

          # A breakpoint on the next queued node halts before executing it.
          breakpoint_response = client.post(
              f"{run_url}/debug/continue",
              params=tenant_params,
              json={"worker_key": "debugger", "breakpoint_node_ids": ["answer"], "max_steps": 5},
          )
          assert breakpoint_response.status_code == 200
          breakpoint_payload = breakpoint_response.json()
          assert breakpoint_payload["reason"] == "breakpoint_hit"
          assert breakpoint_payload["paused_before_node_id"] == "answer"
          assert breakpoint_payload["executed_node_runs"] == []
          assert breakpoint_payload["snapshot"]["queued_node_ids"] == ["answer"]

          # Continuing without breakpoints drains the queue and completes the run.
          finish_response = client.post(
              f"{run_url}/debug/continue",
              params=tenant_params,
              json={"worker_key": "debugger", "max_steps": 5},
          )
          assert finish_response.status_code == 200
          finish_payload = finish_response.json()
          assert finish_payload["snapshot"]["run"]["status"] == "completed"
          assert finish_payload["snapshot"]["completed_node_ids"] == ["start", "answer"]
          assert finish_payload["snapshot"]["queued_node_ids"] == []
      finally:
          # Always release the SQLite engine, even when an assertion fails, so
          # the database file under tmp_path is not left open.
          engine.dispose()