| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- from __future__ import annotations
- from dataclasses import dataclass
- from typing import Literal
- from core_dsl import WorkflowDefinition, parse_workflow_definition
- from core_shared import JSONValue
- from pydantic import ValidationError
# Closed set of severity labels a diagnostic may carry. Only "error" makes an
# inspection invalid; "warning" and "info" are advisory.
DiagnosticSeverity = Literal["error", "warning", "info"]
@dataclass(frozen=True)
class WorkflowDiagnostic:
    """A single finding produced while inspecting a workflow definition.

    Instances are immutable. ``node_id`` and ``edge_index`` are optional
    locators and stay ``None`` for workflow-level findings.
    """

    # One of "error", "warning" or "info".
    severity: DiagnosticSeverity
    # Dotted machine-readable identifier, e.g. "dsl.required".
    code: str
    # Human-readable description of the finding.
    message: str
    # Id of the node the finding refers to, when it concerns a specific node.
    node_id: str | None = None
    # Position of the offending edge in the workflow's edge list, if any.
    edge_index: int | None = None
@dataclass(frozen=True)
class WorkflowNodeInspection:
    """Immutable per-node summary of a workflow graph inspection."""

    # Identifier, type and optional display name of the inspected node.
    id: str
    type: str
    name: str | None
    # Number of edges pointing at / leaving this node.
    incoming_count: int
    outgoing_count: int
    # Whether the node can be reached from an entry node.
    reachable: bool
@dataclass(frozen=True)
class WorkflowEdgeInspection:
    """Immutable per-edge summary of a workflow graph inspection."""

    # Node ids the edge connects, exactly as written in the definition.
    source: str
    target: str
    # Optional condition attached to the edge.
    condition: str | None
    # Whether each endpoint refers to a node id that exists in the workflow.
    valid_source: bool
    valid_target: bool
@dataclass(frozen=True)
class WorkflowDebugStep:
    """One entry of a workflow execution preview."""

    # Zero-based position of this step within the preview.
    step_index: int
    # Identity of the node visited at this step.
    node_id: str
    node_type: str
    name: str | None
    # Ids of the nodes directly downstream of this one.
    next_node_ids: list[str]
@dataclass(frozen=True)
class WorkflowInspection:
    """Aggregate result of statically inspecting a workflow definition."""

    # True when no diagnostic has severity "error".
    valid: bool
    # Every finding collected during inspection, in discovery order.
    diagnostics: list[WorkflowDiagnostic]
    # The parsed definition, or None when the payload could not be parsed.
    workflow: WorkflowDefinition | None
    # Per-node and per-edge summaries.
    nodes: list[WorkflowNodeInspection]
    edges: list[WorkflowEdgeInspection]
    # Nodes without incoming edges / without outgoing edges / without either.
    entry_node_ids: list[str]
    terminal_node_ids: list[str]
    isolated_node_ids: list[str]
    # Nodes that cannot be reached from any entry node.
    unreachable_node_ids: list[str]
    # Whether the graph contains at least one cycle.
    cycle_detected: bool
@dataclass(frozen=True)
class WorkflowDebugPlan:
    """Inspection result bundled with a bounded execution preview."""

    # Static inspection the preview was derived from.
    inspection: WorkflowInspection
    # Ordered preview steps; empty when the workflow is missing or invalid.
    execution_preview: list[WorkflowDebugStep]
    # Upper bound on the number of preview steps that was requested.
    max_preview_steps: int
    # True when the preview ended with nodes still pending.
    truncated: bool
def _invalid_inspection(code: str, message: str) -> WorkflowInspection:
    """Build the empty, invalid inspection used when no workflow could be parsed."""
    return WorkflowInspection(
        valid=False,
        diagnostics=[
            WorkflowDiagnostic(severity="error", code=code, message=message),
        ],
        workflow=None,
        nodes=[],
        edges=[],
        entry_node_ids=[],
        terminal_node_ids=[],
        isolated_node_ids=[],
        unreachable_node_ids=[],
        cycle_detected=False,
    )


def inspect_workflow_dsl(payload: dict[str, JSONValue] | None) -> WorkflowInspection:
    """Parse a workflow DSL payload and statically analyse its graph.

    Args:
        payload: Raw workflow definition, or ``None`` when none was supplied.

    Returns:
        A ``WorkflowInspection``. When the payload is ``None``, fails schema
        validation, or parses to ``None``, the result is invalid, carries a
        single error diagnostic (``dsl.required`` or ``dsl.schema_invalid``)
        and has no workflow. Otherwise it contains per-node and per-edge
        summaries, the entry/terminal/isolated/unreachable node ids, a cycle
        flag, and diagnostics in this order: duplicate node ids, edge endpoint
        errors, missing nodes, multiple entry nodes, missing terminal node,
        unreachable nodes, cycle. ``valid`` is True when none of them is an
        error.
    """
    if payload is None:
        return _invalid_inspection("dsl.required", "workflow dsl_json is required")
    try:
        workflow = parse_workflow_definition(payload)
    except ValidationError as exc:
        return _invalid_inspection("dsl.schema_invalid", str(exc))
    if workflow is None:
        return _invalid_inspection("dsl.required", "workflow dsl_json is required")

    diagnostics: list[WorkflowDiagnostic] = []

    # Single pass over the nodes: collect the distinct ids in declaration
    # order and note which ids occur more than once. Keeping ``node_ids``
    # unique stops a duplicated id from being listed (and warned about)
    # several times in the derived entry/terminal/unreachable results.
    node_ids: list[str] = []
    node_id_set: set[str] = set()
    duplicate_node_ids: set[str] = set()
    for node in workflow.nodes:
        if node.id in node_id_set:
            duplicate_node_ids.add(node.id)
        else:
            node_id_set.add(node.id)
            node_ids.append(node.id)

    for node_id in sorted(duplicate_node_ids):
        diagnostics.append(
            WorkflowDiagnostic(
                severity="error",
                code="node.duplicate_id",
                message=f"duplicate node id: {node_id}",
                node_id=node_id,
            )
        )

    incoming_counts = dict.fromkeys(node_ids, 0)
    outgoing_counts = dict.fromkeys(node_ids, 0)
    adjacency: dict[str, list[str]] = {node_id: [] for node_id in node_ids}
    edge_inspections: list[WorkflowEdgeInspection] = []

    for edge_index, edge in enumerate(workflow.edges):
        valid_source = edge.source in node_id_set
        valid_target = edge.target in node_id_set
        edge_inspections.append(
            WorkflowEdgeInspection(
                source=edge.source,
                target=edge.target,
                condition=edge.condition,
                valid_source=valid_source,
                valid_target=valid_target,
            )
        )
        if not valid_source:
            diagnostics.append(
                WorkflowDiagnostic(
                    severity="error",
                    code="edge.source_missing",
                    message=f"edge source node does not exist: {edge.source}",
                    node_id=edge.source,
                    edge_index=edge_index,
                )
            )
        if not valid_target:
            diagnostics.append(
                WorkflowDiagnostic(
                    severity="error",
                    code="edge.target_missing",
                    message=f"edge target node does not exist: {edge.target}",
                    node_id=edge.target,
                    edge_index=edge_index,
                )
            )
        # Only edges with two known endpoints take part in the graph analysis.
        if valid_source and valid_target:
            outgoing_counts[edge.source] += 1
            incoming_counts[edge.target] += 1
            adjacency[edge.source].append(edge.target)

    if not workflow.nodes:
        diagnostics.append(
            WorkflowDiagnostic(
                severity="error",
                code="workflow.nodes_required",
                message="workflow must contain at least one node",
            )
        )

    entry_node_ids = [node_id for node_id in node_ids if incoming_counts[node_id] == 0]
    terminal_node_ids = [node_id for node_id in node_ids if outgoing_counts[node_id] == 0]
    isolated_node_ids = [
        node_id
        for node_id in node_ids
        if incoming_counts[node_id] == 0 and outgoing_counts[node_id] == 0
    ]

    if len(entry_node_ids) > 1:
        diagnostics.append(
            WorkflowDiagnostic(
                severity="warning",
                code="workflow.multiple_entry_nodes",
                message=f"workflow has multiple entry nodes: {', '.join(entry_node_ids)}",
            )
        )
    if not terminal_node_ids and workflow.nodes:
        diagnostics.append(
            WorkflowDiagnostic(
                severity="warning",
                code="workflow.no_terminal_node",
                message="workflow has no terminal node",
            )
        )

    reachable_node_ids = _find_reachable_nodes(entry_node_ids, adjacency)
    unreachable_node_ids = [
        node_id for node_id in node_ids if node_id not in reachable_node_ids
    ]
    for node_id in unreachable_node_ids:
        diagnostics.append(
            WorkflowDiagnostic(
                severity="warning",
                code="node.unreachable",
                message=f"node is not reachable from an entry node: {node_id}",
                node_id=node_id,
            )
        )

    cycle_detected = _detect_cycle(node_ids, adjacency)
    if cycle_detected:
        diagnostics.append(
            WorkflowDiagnostic(
                severity="warning",
                code="workflow.cycle_detected",
                message="workflow graph contains a cycle; debugger preview may be truncated",
            )
        )

    # One summary per declared node (duplicates included, so the caller sees
    # exactly what the definition contains).
    node_inspections = [
        WorkflowNodeInspection(
            id=node.id,
            type=node.type,
            name=node.name,
            incoming_count=incoming_counts[node.id],
            outgoing_count=outgoing_counts[node.id],
            reachable=node.id in reachable_node_ids,
        )
        for node in workflow.nodes
    ]

    # Warnings and infos never invalidate the workflow; only errors do.
    valid = all(item.severity != "error" for item in diagnostics)
    return WorkflowInspection(
        valid=valid,
        diagnostics=diagnostics,
        workflow=workflow,
        nodes=node_inspections,
        edges=edge_inspections,
        entry_node_ids=entry_node_ids,
        terminal_node_ids=terminal_node_ids,
        isolated_node_ids=isolated_node_ids,
        unreachable_node_ids=unreachable_node_ids,
        cycle_detected=cycle_detected,
    )
def build_debug_plan(
    payload: dict[str, JSONValue] | None,
    *,
    max_preview_steps: int = 50,
) -> WorkflowDebugPlan:
    """Inspect a workflow payload and derive a bounded breadth-first preview.

    Args:
        payload: Raw workflow definition, or ``None`` when none was supplied.
        max_preview_steps: Maximum number of steps to include in the preview.

    Returns:
        A ``WorkflowDebugPlan``. When the inspection has no workflow or is
        invalid, the preview is empty and ``truncated`` is False. Otherwise
        the preview walks the graph breadth-first from the entry nodes,
        visiting each node at most once, and ``truncated`` is True only when
        the step limit stopped the walk while unvisited nodes were still
        queued.
    """
    from collections import deque

    inspection = inspect_workflow_dsl(payload)
    workflow = inspection.workflow
    if workflow is None or not inspection.valid:
        return WorkflowDebugPlan(
            inspection=inspection,
            execution_preview=[],
            max_preview_steps=max_preview_steps,
            truncated=False,
        )

    node_map = {node.id: node for node in workflow.nodes}
    adjacency: dict[str, list[str]] = {node_id: [] for node_id in node_map}
    for edge in workflow.edges:
        if edge.source in node_map and edge.target in node_map:
            adjacency[edge.source].append(edge.target)

    preview: list[WorkflowDebugStep] = []
    visited: set[str] = set()
    # deque gives O(1) pops from the front; list.pop(0) shifts every element.
    queue = deque(inspection.entry_node_ids)

    while queue and len(preview) < max_preview_steps:
        node_id = queue.popleft()
        if node_id in visited:
            continue
        node = node_map.get(node_id)
        if node is None:
            continue
        visited.add(node_id)
        next_node_ids = adjacency[node_id]
        preview.append(
            WorkflowDebugStep(
                step_index=len(preview),
                node_id=node.id,
                node_type=node.type,
                name=node.name,
                # Copy so the step does not alias the working adjacency list.
                next_node_ids=list(next_node_ids),
            )
        )
        queue.extend(
            next_id for next_id in next_node_ids if next_id not in visited
        )

    # Leftover queue entries that were already visited (e.g. the second path
    # into the join of a diamond) are not pending work, so they must not mark
    # the preview as truncated.
    truncated = any(
        pending in node_map and pending not in visited for pending in queue
    )

    return WorkflowDebugPlan(
        inspection=inspection,
        execution_preview=preview,
        max_preview_steps=max_preview_steps,
        truncated=truncated,
    )
def _find_reachable_nodes(entry_node_ids: list[str], adjacency: dict[str, list[str]]) -> set[str]:
    """Return every node id reachable from the given entry nodes.

    Args:
        entry_node_ids: Ids the traversal starts from.
        adjacency: Mapping of node id to the ids of its direct successors;
            ids missing from the mapping are treated as having no successors.

    Returns:
        The set of visited ids, including the entry nodes themselves.
    """
    seen: set[str] = set()
    pending = [*entry_node_ids]
    while pending:
        current = pending.pop()
        if current not in seen:
            seen.add(current)
            pending += adjacency.get(current, [])
    return seen
def _detect_cycle(node_ids: list[str], adjacency: dict[str, list[str]]) -> bool:
    """Report whether the directed graph contains a cycle.

    Runs a depth-first search with an explicit stack rather than recursion,
    so long linear chains cannot exhaust the interpreter's recursion limit.

    Args:
        node_ids: Ids to start searches from, tried in order.
        adjacency: Mapping of node id to the ids of its direct successors;
            ids missing from the mapping are treated as having no successors.

    Returns:
        True if some node can reach itself by following edges (self-loops
        included), otherwise False.
    """
    visiting: set[str] = set()  # nodes on the current DFS path
    visited: set[str] = set()  # nodes fully explored and known cycle-free

    for root in node_ids:
        if root in visited:
            continue
        visiting.add(root)
        # Each frame pairs a node with an iterator over its remaining
        # successors, so a node resumes where it left off after a child
        # has been explored.
        stack = [(root, iter(adjacency.get(root, [])))]
        while stack:
            current, successors = stack[-1]
            for successor in successors:
                if successor in visiting:
                    # Back edge to a node on the current path.
                    return True
                if successor not in visited:
                    visiting.add(successor)
                    stack.append((successor, iter(adjacency.get(successor, []))))
                    break
            else:
                # All successors explored: leave the path and mark as done.
                stack.pop()
                visiting.discard(current)
                visited.add(current)
    return False
|