test_runtime_debugger.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from __future__ import annotations
  2. from collections.abc import Generator
  3. from datetime import datetime
  4. from pathlib import Path
  5. from sqlalchemy.orm import Session
  6. from tests.conftest import (
  7. build_fastapi_test_client,
  8. build_sqlite_database_url,
  9. prepare_known_service_import,
  10. )
  11. def test_runtime_debugger_pause_step_and_breakpoint_continue(tmp_path: Path) -> None:
  12. prepare_known_service_import("runtime-service")
  13. from app.api.routes import get_runtime_application_service
  14. from app.application.services import RuntimeApplicationService
  15. from app.bootstrap.app import create_app
  16. from app.bootstrap.settings import RuntimeServiceSettings
  17. from app.db.session import build_session_factory
  18. from app.domain.repositories import (
  19. ExecutionLogRepository,
  20. NodeArtifactRepository,
  21. NodeRunRepository,
  22. TraceSpanRepository,
  23. WorkflowRunRepository,
  24. )
  25. from app.infrastructure.executors import build_node_execution_dispatcher
  26. from core_db import Base
  27. from core_domain import WorkflowVersionContract
  28. class FakeWorkflowClient:
  29. def get_workflow_version(
  30. self,
  31. *,
  32. workflow_version_id: str) -> WorkflowVersionContract:
  33. assert workflow_version_id == "wv1"
  34. return WorkflowVersionContract(
  35. id="wv1",
  36. workflow_id="wf1",
  37. version_no=1,
  38. status="published",
  39. created_time=datetime.utcnow(),
  40. dsl_json={
  41. "code": "debug_flow",
  42. "nodes": [
  43. {"id": "start", "type": "template", "config": {"template": "hello"}},
  44. {"id": "answer", "type": "answer", "config": {"text": "done"}},
  45. ],
  46. "edges": [{"source": "start", "target": "answer"}],
  47. })
  48. settings = RuntimeServiceSettings(database_url=build_sqlite_database_url(tmp_path, "runtime.db"))
  49. session_factory = build_session_factory(settings)
  50. engine = session_factory.kw["bind"]
  51. Base.metadata.create_all(bind=engine)
  52. def override_service() -> Generator[RuntimeApplicationService, None, None]:
  53. db: Session = session_factory()
  54. try:
  55. yield RuntimeApplicationService(
  56. workflow_run_repository=WorkflowRunRepository(db),
  57. node_run_repository=NodeRunRepository(db),
  58. execution_log_repository=ExecutionLogRepository(db),
  59. node_artifact_repository=NodeArtifactRepository(db),
  60. trace_span_repository=TraceSpanRepository(db),
  61. execution_dispatcher=build_node_execution_dispatcher(),
  62. workflow_client=FakeWorkflowClient())
  63. finally:
  64. db.close()
  65. app = create_app()
  66. app.state.session_factory = session_factory
  67. app.dependency_overrides[get_runtime_application_service] = override_service
  68. client = build_fastapi_test_client(app)
  69. create_response = client.post(
  70. "/runtime/runs",
  71. json={
  72. "app_id": "app1",
  73. "app_version_id": "av1",
  74. "workflow_id": "wf1",
  75. "workflow_version_id": "wv1",
  76. })
  77. assert create_response.status_code == 200
  78. run_id = create_response.json()["run"]["id"]
  79. pause_response = client.post(
  80. f"/runtime/runs/{run_id}/debug/pause")
  81. assert pause_response.status_code == 200
  82. assert pause_response.json()["run"]["status"] == "paused"
  83. assert pause_response.json()["queued_node_ids"] == ["start"]
  84. protected_execute_response = client.post(
  85. f"/runtime/runs/{run_id}/execute-next",
  86. json={"worker_key": "debugger"})
  87. assert protected_execute_response.status_code == 200
  88. assert protected_execute_response.json()["executor_name"] == "debug_paused"
  89. assert protected_execute_response.json()["node_run"]["status"] == "queued"
  90. step_response = client.post(
  91. f"/runtime/runs/{run_id}/debug/step",
  92. json={"worker_key": "debugger"})
  93. assert step_response.status_code == 200
  94. step_payload = step_response.json()
  95. assert step_payload["reason"] == "step_completed"
  96. assert [item["node_id"] for item in step_payload["executed_node_runs"]] == ["start"]
  97. assert step_payload["snapshot"]["run"]["status"] == "paused"
  98. assert step_payload["snapshot"]["completed_node_ids"] == ["start"]
  99. assert step_payload["snapshot"]["queued_node_ids"] == ["answer"]
  100. breakpoint_response = client.post(
  101. f"/runtime/runs/{run_id}/debug/continue",
  102. json={"worker_key": "debugger", "breakpoint_node_ids": ["answer"], "max_steps": 5})
  103. assert breakpoint_response.status_code == 200
  104. breakpoint_payload = breakpoint_response.json()
  105. assert breakpoint_payload["reason"] == "breakpoint_hit"
  106. assert breakpoint_payload["paused_before_node_id"] == "answer"
  107. assert breakpoint_payload["executed_node_runs"] == []
  108. assert breakpoint_payload["snapshot"]["queued_node_ids"] == ["answer"]
  109. finish_response = client.post(
  110. f"/runtime/runs/{run_id}/debug/continue",
  111. json={"worker_key": "debugger", "max_steps": 5})
  112. assert finish_response.status_code == 200
  113. finish_payload = finish_response.json()
  114. assert finish_payload["snapshot"]["run"]["status"] == "completed"
  115. assert finish_payload["snapshot"]["completed_node_ids"] == ["start", "answer"]
  116. assert finish_payload["snapshot"]["queued_node_ids"] == []
  117. engine.dispose()