designer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Literal
  4. from core_dsl import WorkflowDefinition, parse_workflow_definition
  5. from core_shared import JSONValue
  6. from pydantic import ValidationError
# Severity levels a diagnostic may carry. Only "error" diagnostics make an
# inspection invalid; "warning" and "info" are advisory.
DiagnosticSeverity = Literal["error", "warning", "info"]
@dataclass(frozen=True)
class WorkflowDiagnostic:
    """A single finding reported while inspecting a workflow definition."""

    # One of "error", "warning" or "info".
    severity: DiagnosticSeverity
    # Machine-readable identifier, e.g. "dsl.required" or "node.duplicate_id".
    code: str
    # Human-readable description of the finding.
    message: str
    # Id of the node the finding refers to, when it concerns a specific node.
    # For edge findings this is the missing source/target id.
    node_id: str | None = None
    # Position of the offending edge in the workflow's edge list, when the
    # finding concerns a specific edge.
    edge_index: int | None = None
@dataclass(frozen=True)
class WorkflowNodeInspection:
    """Per-node summary produced by the workflow inspection."""

    # Node identifier, type and optional display name copied from the workflow.
    id: str
    type: str
    name: str | None
    # Number of edges pointing at / leaving this node. Only edges whose source
    # and target both exist in the workflow are counted.
    incoming_count: int
    outgoing_count: int
    # True when the node can be reached from at least one entry node.
    reachable: bool
@dataclass(frozen=True)
class WorkflowEdgeInspection:
    """Per-edge summary produced by the workflow inspection."""

    # Endpoint ids and optional condition copied from the workflow edge.
    source: str
    target: str
    condition: str | None
    # Whether the source / target id refers to a node defined in the workflow.
    valid_source: bool
    valid_target: bool
@dataclass(frozen=True)
class WorkflowDebugStep:
    """One entry of the execution preview built for the debugger."""

    # Zero-based position of this step within the preview.
    step_index: int
    # Identifier, type and optional display name of the node visited.
    node_id: str
    node_type: str
    name: str | None
    # Ids of the nodes this node has outgoing edges to.
    next_node_ids: list[str]
@dataclass(frozen=True)
class WorkflowInspection:
    """Result of statically inspecting a workflow DSL payload."""

    # True when no diagnostic has severity "error".
    valid: bool
    # Every finding collected during inspection.
    diagnostics: list[WorkflowDiagnostic]
    # The parsed workflow, or None when the payload was missing or failed to parse.
    workflow: WorkflowDefinition | None
    # Per-node and per-edge summaries; empty when no workflow could be parsed.
    nodes: list[WorkflowNodeInspection]
    edges: list[WorkflowEdgeInspection]
    # Nodes without incoming edges.
    entry_node_ids: list[str]
    # Nodes without outgoing edges.
    terminal_node_ids: list[str]
    # Nodes with neither incoming nor outgoing edges.
    isolated_node_ids: list[str]
    # Nodes that cannot be reached from any entry node.
    unreachable_node_ids: list[str]
    # True when the graph formed by the valid edges contains a cycle.
    cycle_detected: bool
@dataclass(frozen=True)
class WorkflowDebugPlan:
    """Inspection result paired with a bounded preview of the visiting order."""

    # Static inspection the preview was derived from.
    inspection: WorkflowInspection
    # Nodes in the order the preview traversal visited them; empty when the
    # workflow is missing or invalid.
    execution_preview: list[WorkflowDebugStep]
    # Upper bound on the number of preview steps that was requested.
    max_preview_steps: int
    # True when the preview stopped before the traversal queue was exhausted.
    truncated: bool
  55. def inspect_workflow_dsl(payload: dict[str, JSONValue] | None) -> WorkflowInspection:
  56. if payload is None:
  57. return WorkflowInspection(
  58. valid=False,
  59. diagnostics=[
  60. WorkflowDiagnostic(
  61. severity="error",
  62. code="dsl.required",
  63. message="workflow dsl_json is required")
  64. ],
  65. workflow=None,
  66. nodes=[],
  67. edges=[],
  68. entry_node_ids=[],
  69. terminal_node_ids=[],
  70. isolated_node_ids=[],
  71. unreachable_node_ids=[],
  72. cycle_detected=False)
  73. try:
  74. workflow = parse_workflow_definition(payload)
  75. except ValidationError as exc:
  76. return WorkflowInspection(
  77. valid=False,
  78. diagnostics=[
  79. WorkflowDiagnostic(
  80. severity="error",
  81. code="dsl.schema_invalid",
  82. message=str(exc))
  83. ],
  84. workflow=None,
  85. nodes=[],
  86. edges=[],
  87. entry_node_ids=[],
  88. terminal_node_ids=[],
  89. isolated_node_ids=[],
  90. unreachable_node_ids=[],
  91. cycle_detected=False)
  92. if workflow is None:
  93. return WorkflowInspection(
  94. valid=False,
  95. diagnostics=[
  96. WorkflowDiagnostic(
  97. severity="error",
  98. code="dsl.required",
  99. message="workflow dsl_json is required")
  100. ],
  101. workflow=None,
  102. nodes=[],
  103. edges=[],
  104. entry_node_ids=[],
  105. terminal_node_ids=[],
  106. isolated_node_ids=[],
  107. unreachable_node_ids=[],
  108. cycle_detected=False)
  109. diagnostics: list[WorkflowDiagnostic] = []
  110. node_ids = [node.id for node in workflow.nodes]
  111. node_id_set = set(node_ids)
  112. duplicate_node_ids = sorted({node_id for node_id in node_ids if node_ids.count(node_id) > 1})
  113. for node_id in duplicate_node_ids:
  114. diagnostics.append(
  115. WorkflowDiagnostic(
  116. severity="error",
  117. code="node.duplicate_id",
  118. message=f"duplicate node id: {node_id}",
  119. node_id=node_id)
  120. )
  121. incoming_counts = {node_id: 0 for node_id in node_ids}
  122. outgoing_counts = {node_id: 0 for node_id in node_ids}
  123. adjacency: dict[str, list[str]] = {node_id: [] for node_id in node_ids}
  124. edge_inspections: list[WorkflowEdgeInspection] = []
  125. for edge_index, edge in enumerate(workflow.edges):
  126. valid_source = edge.source in node_id_set
  127. valid_target = edge.target in node_id_set
  128. edge_inspections.append(
  129. WorkflowEdgeInspection(
  130. source=edge.source,
  131. target=edge.target,
  132. condition=edge.condition,
  133. valid_source=valid_source,
  134. valid_target=valid_target)
  135. )
  136. if not valid_source:
  137. diagnostics.append(
  138. WorkflowDiagnostic(
  139. severity="error",
  140. code="edge.source_missing",
  141. message=f"edge source node does not exist: {edge.source}",
  142. node_id=edge.source,
  143. edge_index=edge_index)
  144. )
  145. if not valid_target:
  146. diagnostics.append(
  147. WorkflowDiagnostic(
  148. severity="error",
  149. code="edge.target_missing",
  150. message=f"edge target node does not exist: {edge.target}",
  151. node_id=edge.target,
  152. edge_index=edge_index)
  153. )
  154. if valid_source and valid_target:
  155. outgoing_counts[edge.source] = outgoing_counts.get(edge.source, 0) + 1
  156. incoming_counts[edge.target] = incoming_counts.get(edge.target, 0) + 1
  157. adjacency.setdefault(edge.source, []).append(edge.target)
  158. if not workflow.nodes:
  159. diagnostics.append(
  160. WorkflowDiagnostic(
  161. severity="error",
  162. code="workflow.nodes_required",
  163. message="workflow must contain at least one node")
  164. )
  165. entry_node_ids = [node_id for node_id in node_ids if incoming_counts.get(node_id, 0) == 0]
  166. terminal_node_ids = [node_id for node_id in node_ids if outgoing_counts.get(node_id, 0) == 0]
  167. isolated_node_ids = [
  168. node_id
  169. for node_id in node_ids
  170. if incoming_counts.get(node_id, 0) == 0 and outgoing_counts.get(node_id, 0) == 0
  171. ]
  172. if len(entry_node_ids) > 1:
  173. diagnostics.append(
  174. WorkflowDiagnostic(
  175. severity="warning",
  176. code="workflow.multiple_entry_nodes",
  177. message=f"workflow has multiple entry nodes: {', '.join(entry_node_ids)}")
  178. )
  179. if not terminal_node_ids and workflow.nodes:
  180. diagnostics.append(
  181. WorkflowDiagnostic(
  182. severity="warning",
  183. code="workflow.no_terminal_node",
  184. message="workflow has no terminal node")
  185. )
  186. reachable_node_ids = _find_reachable_nodes(entry_node_ids, adjacency)
  187. unreachable_node_ids = [node_id for node_id in node_ids if node_id not in reachable_node_ids]
  188. for node_id in unreachable_node_ids:
  189. diagnostics.append(
  190. WorkflowDiagnostic(
  191. severity="warning",
  192. code="node.unreachable",
  193. message=f"node is not reachable from an entry node: {node_id}",
  194. node_id=node_id)
  195. )
  196. cycle_detected = _detect_cycle(node_ids, adjacency)
  197. if cycle_detected:
  198. diagnostics.append(
  199. WorkflowDiagnostic(
  200. severity="warning",
  201. code="workflow.cycle_detected",
  202. message="workflow graph contains a cycle; debugger preview may be truncated")
  203. )
  204. node_inspections = [
  205. WorkflowNodeInspection(
  206. id=node.id,
  207. type=node.type,
  208. name=node.name,
  209. incoming_count=incoming_counts.get(node.id, 0),
  210. outgoing_count=outgoing_counts.get(node.id, 0),
  211. reachable=node.id in reachable_node_ids)
  212. for node in workflow.nodes
  213. ]
  214. valid = not any(item.severity == "error" for item in diagnostics)
  215. return WorkflowInspection(
  216. valid=valid,
  217. diagnostics=diagnostics,
  218. workflow=workflow,
  219. nodes=node_inspections,
  220. edges=edge_inspections,
  221. entry_node_ids=entry_node_ids,
  222. terminal_node_ids=terminal_node_ids,
  223. isolated_node_ids=isolated_node_ids,
  224. unreachable_node_ids=unreachable_node_ids,
  225. cycle_detected=cycle_detected)
  226. def build_debug_plan(
  227. payload: dict[str, JSONValue] | None,
  228. *,
  229. max_preview_steps: int = 50) -> WorkflowDebugPlan:
  230. inspection = inspect_workflow_dsl(payload)
  231. workflow = inspection.workflow
  232. if workflow is None or not inspection.valid:
  233. return WorkflowDebugPlan(
  234. inspection=inspection,
  235. execution_preview=[],
  236. max_preview_steps=max_preview_steps,
  237. truncated=False)
  238. node_map = {node.id: node for node in workflow.nodes}
  239. adjacency: dict[str, list[str]] = {node.id: [] for node in workflow.nodes}
  240. for edge in workflow.edges:
  241. if edge.source in node_map and edge.target in node_map:
  242. adjacency.setdefault(edge.source, []).append(edge.target)
  243. preview: list[WorkflowDebugStep] = []
  244. queue = list(inspection.entry_node_ids)
  245. seen_visits: dict[str, int] = {}
  246. truncated = False
  247. while queue and len(preview) < max_preview_steps:
  248. node_id = queue.pop(0)
  249. node = node_map.get(node_id)
  250. if node is None:
  251. continue
  252. seen_visits[node_id] = seen_visits.get(node_id, 0) + 1
  253. if seen_visits[node_id] > 1:
  254. continue
  255. next_node_ids = adjacency.get(node_id, [])
  256. preview.append(
  257. WorkflowDebugStep(
  258. step_index=len(preview),
  259. node_id=node.id,
  260. node_type=node.type,
  261. name=node.name,
  262. next_node_ids=next_node_ids)
  263. )
  264. queue.extend(next_node_ids)
  265. if queue:
  266. truncated = True
  267. return WorkflowDebugPlan(
  268. inspection=inspection,
  269. execution_preview=preview,
  270. max_preview_steps=max_preview_steps,
  271. truncated=truncated)
  272. def _find_reachable_nodes(entry_node_ids: list[str], adjacency: dict[str, list[str]]) -> set[str]:
  273. reachable: set[str] = set()
  274. stack = list(entry_node_ids)
  275. while stack:
  276. node_id = stack.pop()
  277. if node_id in reachable:
  278. continue
  279. reachable.add(node_id)
  280. stack.extend(adjacency.get(node_id, []))
  281. return reachable
  282. def _detect_cycle(node_ids: list[str], adjacency: dict[str, list[str]]) -> bool:
  283. visiting: set[str] = set()
  284. visited: set[str] = set()
  285. def visit(node_id: str) -> bool:
  286. if node_id in visiting:
  287. return True
  288. if node_id in visited:
  289. return False
  290. visiting.add(node_id)
  291. for next_node_id in adjacency.get(node_id, []):
  292. if visit(next_node_id):
  293. return True
  294. visiting.remove(node_id)
  295. visited.add(node_id)
  296. return False
  297. return any(visit(node_id) for node_id in node_ids)