designer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import Literal
  4. from pydantic import ValidationError
  5. from core_dsl import WorkflowDefinition, parse_workflow_definition
  6. from core_shared import JSONValue
  7. DiagnosticSeverity = Literal["error", "warning", "info"]
  8. @dataclass(frozen=True)
  9. class WorkflowDiagnostic:
  10. severity: DiagnosticSeverity
  11. code: str
  12. message: str
  13. node_id: str | None = None
  14. edge_index: int | None = None
  15. @dataclass(frozen=True)
  16. class WorkflowNodeInspection:
  17. id: str
  18. type: str
  19. name: str | None
  20. incoming_count: int
  21. outgoing_count: int
  22. reachable: bool
  23. @dataclass(frozen=True)
  24. class WorkflowEdgeInspection:
  25. source: str
  26. target: str
  27. condition: str | None
  28. valid_source: bool
  29. valid_target: bool
  30. @dataclass(frozen=True)
  31. class WorkflowDebugStep:
  32. step_index: int
  33. node_id: str
  34. node_type: str
  35. name: str | None
  36. next_node_ids: list[str]
  37. @dataclass(frozen=True)
  38. class WorkflowInspection:
  39. valid: bool
  40. diagnostics: list[WorkflowDiagnostic]
  41. workflow: WorkflowDefinition | None
  42. nodes: list[WorkflowNodeInspection]
  43. edges: list[WorkflowEdgeInspection]
  44. entry_node_ids: list[str]
  45. terminal_node_ids: list[str]
  46. isolated_node_ids: list[str]
  47. unreachable_node_ids: list[str]
  48. cycle_detected: bool
  49. @dataclass(frozen=True)
  50. class WorkflowDebugPlan:
  51. inspection: WorkflowInspection
  52. execution_preview: list[WorkflowDebugStep]
  53. max_preview_steps: int
  54. truncated: bool
  55. def inspect_workflow_dsl(payload: dict[str, JSONValue] | None) -> WorkflowInspection:
  56. if payload is None:
  57. return WorkflowInspection(
  58. valid=False,
  59. diagnostics=[
  60. WorkflowDiagnostic(
  61. severity="error",
  62. code="dsl.required",
  63. message="workflow dsl_json is required",
  64. )
  65. ],
  66. workflow=None,
  67. nodes=[],
  68. edges=[],
  69. entry_node_ids=[],
  70. terminal_node_ids=[],
  71. isolated_node_ids=[],
  72. unreachable_node_ids=[],
  73. cycle_detected=False,
  74. )
  75. try:
  76. workflow = parse_workflow_definition(payload)
  77. except ValidationError as exc:
  78. return WorkflowInspection(
  79. valid=False,
  80. diagnostics=[
  81. WorkflowDiagnostic(
  82. severity="error",
  83. code="dsl.schema_invalid",
  84. message=str(exc),
  85. )
  86. ],
  87. workflow=None,
  88. nodes=[],
  89. edges=[],
  90. entry_node_ids=[],
  91. terminal_node_ids=[],
  92. isolated_node_ids=[],
  93. unreachable_node_ids=[],
  94. cycle_detected=False,
  95. )
  96. if workflow is None:
  97. return WorkflowInspection(
  98. valid=False,
  99. diagnostics=[
  100. WorkflowDiagnostic(
  101. severity="error",
  102. code="dsl.required",
  103. message="workflow dsl_json is required",
  104. )
  105. ],
  106. workflow=None,
  107. nodes=[],
  108. edges=[],
  109. entry_node_ids=[],
  110. terminal_node_ids=[],
  111. isolated_node_ids=[],
  112. unreachable_node_ids=[],
  113. cycle_detected=False,
  114. )
  115. diagnostics: list[WorkflowDiagnostic] = []
  116. node_ids = [node.id for node in workflow.nodes]
  117. node_id_set = set(node_ids)
  118. duplicate_node_ids = sorted({node_id for node_id in node_ids if node_ids.count(node_id) > 1})
  119. for node_id in duplicate_node_ids:
  120. diagnostics.append(
  121. WorkflowDiagnostic(
  122. severity="error",
  123. code="node.duplicate_id",
  124. message=f"duplicate node id: {node_id}",
  125. node_id=node_id,
  126. )
  127. )
  128. incoming_counts = {node_id: 0 for node_id in node_ids}
  129. outgoing_counts = {node_id: 0 for node_id in node_ids}
  130. adjacency: dict[str, list[str]] = {node_id: [] for node_id in node_ids}
  131. edge_inspections: list[WorkflowEdgeInspection] = []
  132. for edge_index, edge in enumerate(workflow.edges):
  133. valid_source = edge.source in node_id_set
  134. valid_target = edge.target in node_id_set
  135. edge_inspections.append(
  136. WorkflowEdgeInspection(
  137. source=edge.source,
  138. target=edge.target,
  139. condition=edge.condition,
  140. valid_source=valid_source,
  141. valid_target=valid_target,
  142. )
  143. )
  144. if not valid_source:
  145. diagnostics.append(
  146. WorkflowDiagnostic(
  147. severity="error",
  148. code="edge.source_missing",
  149. message=f"edge source node does not exist: {edge.source}",
  150. node_id=edge.source,
  151. edge_index=edge_index,
  152. )
  153. )
  154. if not valid_target:
  155. diagnostics.append(
  156. WorkflowDiagnostic(
  157. severity="error",
  158. code="edge.target_missing",
  159. message=f"edge target node does not exist: {edge.target}",
  160. node_id=edge.target,
  161. edge_index=edge_index,
  162. )
  163. )
  164. if valid_source and valid_target:
  165. outgoing_counts[edge.source] = outgoing_counts.get(edge.source, 0) + 1
  166. incoming_counts[edge.target] = incoming_counts.get(edge.target, 0) + 1
  167. adjacency.setdefault(edge.source, []).append(edge.target)
  168. if not workflow.nodes:
  169. diagnostics.append(
  170. WorkflowDiagnostic(
  171. severity="error",
  172. code="workflow.nodes_required",
  173. message="workflow must contain at least one node",
  174. )
  175. )
  176. entry_node_ids = [node_id for node_id in node_ids if incoming_counts.get(node_id, 0) == 0]
  177. terminal_node_ids = [node_id for node_id in node_ids if outgoing_counts.get(node_id, 0) == 0]
  178. isolated_node_ids = [
  179. node_id
  180. for node_id in node_ids
  181. if incoming_counts.get(node_id, 0) == 0 and outgoing_counts.get(node_id, 0) == 0
  182. ]
  183. if len(entry_node_ids) > 1:
  184. diagnostics.append(
  185. WorkflowDiagnostic(
  186. severity="warning",
  187. code="workflow.multiple_entry_nodes",
  188. message=f"workflow has multiple entry nodes: {', '.join(entry_node_ids)}",
  189. )
  190. )
  191. if not terminal_node_ids and workflow.nodes:
  192. diagnostics.append(
  193. WorkflowDiagnostic(
  194. severity="warning",
  195. code="workflow.no_terminal_node",
  196. message="workflow has no terminal node",
  197. )
  198. )
  199. reachable_node_ids = _find_reachable_nodes(entry_node_ids, adjacency)
  200. unreachable_node_ids = [node_id for node_id in node_ids if node_id not in reachable_node_ids]
  201. for node_id in unreachable_node_ids:
  202. diagnostics.append(
  203. WorkflowDiagnostic(
  204. severity="warning",
  205. code="node.unreachable",
  206. message=f"node is not reachable from an entry node: {node_id}",
  207. node_id=node_id,
  208. )
  209. )
  210. cycle_detected = _detect_cycle(node_ids, adjacency)
  211. if cycle_detected:
  212. diagnostics.append(
  213. WorkflowDiagnostic(
  214. severity="warning",
  215. code="workflow.cycle_detected",
  216. message="workflow graph contains a cycle; debugger preview may be truncated",
  217. )
  218. )
  219. node_inspections = [
  220. WorkflowNodeInspection(
  221. id=node.id,
  222. type=node.type,
  223. name=node.name,
  224. incoming_count=incoming_counts.get(node.id, 0),
  225. outgoing_count=outgoing_counts.get(node.id, 0),
  226. reachable=node.id in reachable_node_ids,
  227. )
  228. for node in workflow.nodes
  229. ]
  230. valid = not any(item.severity == "error" for item in diagnostics)
  231. return WorkflowInspection(
  232. valid=valid,
  233. diagnostics=diagnostics,
  234. workflow=workflow,
  235. nodes=node_inspections,
  236. edges=edge_inspections,
  237. entry_node_ids=entry_node_ids,
  238. terminal_node_ids=terminal_node_ids,
  239. isolated_node_ids=isolated_node_ids,
  240. unreachable_node_ids=unreachable_node_ids,
  241. cycle_detected=cycle_detected,
  242. )
  243. def build_debug_plan(
  244. payload: dict[str, JSONValue] | None,
  245. *,
  246. max_preview_steps: int = 50,
  247. ) -> WorkflowDebugPlan:
  248. inspection = inspect_workflow_dsl(payload)
  249. workflow = inspection.workflow
  250. if workflow is None or not inspection.valid:
  251. return WorkflowDebugPlan(
  252. inspection=inspection,
  253. execution_preview=[],
  254. max_preview_steps=max_preview_steps,
  255. truncated=False,
  256. )
  257. node_map = {node.id: node for node in workflow.nodes}
  258. adjacency: dict[str, list[str]] = {node.id: [] for node in workflow.nodes}
  259. for edge in workflow.edges:
  260. if edge.source in node_map and edge.target in node_map:
  261. adjacency.setdefault(edge.source, []).append(edge.target)
  262. preview: list[WorkflowDebugStep] = []
  263. queue = list(inspection.entry_node_ids)
  264. seen_visits: dict[str, int] = {}
  265. truncated = False
  266. while queue and len(preview) < max_preview_steps:
  267. node_id = queue.pop(0)
  268. node = node_map.get(node_id)
  269. if node is None:
  270. continue
  271. seen_visits[node_id] = seen_visits.get(node_id, 0) + 1
  272. if seen_visits[node_id] > 1:
  273. continue
  274. next_node_ids = adjacency.get(node_id, [])
  275. preview.append(
  276. WorkflowDebugStep(
  277. step_index=len(preview),
  278. node_id=node.id,
  279. node_type=node.type,
  280. name=node.name,
  281. next_node_ids=next_node_ids,
  282. )
  283. )
  284. queue.extend(next_node_ids)
  285. if queue:
  286. truncated = True
  287. return WorkflowDebugPlan(
  288. inspection=inspection,
  289. execution_preview=preview,
  290. max_preview_steps=max_preview_steps,
  291. truncated=truncated,
  292. )
  293. def _find_reachable_nodes(entry_node_ids: list[str], adjacency: dict[str, list[str]]) -> set[str]:
  294. reachable: set[str] = set()
  295. stack = list(entry_node_ids)
  296. while stack:
  297. node_id = stack.pop()
  298. if node_id in reachable:
  299. continue
  300. reachable.add(node_id)
  301. stack.extend(adjacency.get(node_id, []))
  302. return reachable
  303. def _detect_cycle(node_ids: list[str], adjacency: dict[str, list[str]]) -> bool:
  304. visiting: set[str] = set()
  305. visited: set[str] = set()
  306. def visit(node_id: str) -> bool:
  307. if node_id in visiting:
  308. return True
  309. if node_id in visited:
  310. return False
  311. visiting.add(node_id)
  312. for next_node_id in adjacency.get(node_id, []):
  313. if visit(next_node_id):
  314. return True
  315. visiting.remove(node_id)
  316. visited.add(node_id)
  317. return False
  318. return any(visit(node_id) for node_id in node_ids)