planner.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from core_domain import InitialNodeContract, WorkflowVersionContract
  2. from core_dsl import (
  3. EdgeDefinition,
  4. get_initial_node_definition,
  5. get_node_definition,
  6. parse_workflow_definition,
  7. )
  8. from core_shared import JSONValue
  9. from .context import build_template_context, evaluate_condition_expression
  def derive_initial_node(workflow_version: WorkflowVersionContract) -> InitialNodeContract | None:
      """Build the queued contract for the workflow's entry node.

      The workflow DSL stored on ``workflow_version.dsl_json`` is parsed, and the
      initial node of the resulting workflow is looked up.

      Args:
          workflow_version: Version whose ``dsl_json`` holds the workflow DSL.

      Returns:
          An ``InitialNodeContract`` with ``status="queued"`` for the initial
          node, or ``None`` when ``parse_workflow_definition`` returns ``None``
          or the parsed workflow has no initial node.
      """
      parsed = parse_workflow_definition(workflow_version.dsl_json)
      # Only search for an entry node when parsing actually produced a workflow.
      entry = get_initial_node_definition(parsed) if parsed is not None else None
      if entry is None:
          return None
      return InitialNodeContract(node_id=entry.id, node_type=entry.type, status="queued")
  def derive_successor_nodes(
      workflow_version: WorkflowVersionContract,
      current_node_id: str,
      current_output_json: dict[str, JSONValue] | None = None,
      run_state_json: dict[str, JSONValue] | None = None,
      node_output_json_by_node_id: dict[str, dict[str, JSONValue]] | None = None,
      node_output_text_by_node_id: dict[str, str] | None = None,
  ) -> list[InitialNodeContract]:
      """Return the nodes to queue after ``current_node_id`` has produced output.

      The workflow DSL on ``workflow_version.dsl_json`` is parsed. The outgoing
      edges of ``current_node_id`` are then checked against an edge context:
      the template context from ``build_template_context``, extended with the
      keys below. These keys override same-named template keys.

      - ``"output"``: ``current_output_json``, or ``{}`` when it is falsy.
      - ``"route"``: the output's ``"route"`` value if it is a ``str``, else ``None``.
      - ``"condition_result"``: the output's ``"condition_result"`` value if it
        is a ``bool``, else ``None``.

      Args:
          workflow_version: Version whose ``dsl_json`` holds the workflow DSL.
          current_node_id: Id of the node that just finished.
          current_output_json: JSON output of that node, if any.
          run_state_json: Run-level state forwarded to the template context.
          node_output_json_by_node_id: Per-node JSON outputs forwarded to the
              template context.
          node_output_text_by_node_id: Per-node text outputs forwarded to the
              template context.

      Returns:
          One ``InitialNodeContract`` (``status="queued"``) per matching edge,
          in edge order. Edges whose target is not a declared node are skipped.
          A target reached by several matching edges appears once per edge.
          Returns ``[]`` when ``parse_workflow_definition`` returns ``None``.
      """
      workflow = parse_workflow_definition(workflow_version.dsl_json)
      if workflow is None:
          return []

      node_map = {node.id: node for node in workflow.nodes}
      # Look the current node up once and narrow the Optional explicitly; the
      # previous `node_map.get(...).type` relied on a separate membership test
      # that type checkers cannot see through.
      current_node = node_map.get(current_node_id)
      current_node_type = current_node.type if current_node is not None else "unknown"
      # Normalize once; the same (possibly empty) mapping feeds all three keys.
      current_output = current_output_json or {}

      template_context = build_template_context(
          node_id=current_node_id,
          node_type=current_node_type,
          run_state_json=run_state_json or {},
          node_output_json_by_node_id=node_output_json_by_node_id or {},
          node_output_text_by_node_id=node_output_text_by_node_id or {},
      )
      # Later keys win, so these three override any same-named template keys.
      edge_context: dict[str, JSONValue] = {
          **template_context,
          "output": current_output,
          "route": _read_string_value(current_output, "route"),
          "condition_result": _read_bool_value(current_output, "condition_result"),
      }

      matching_edges = _get_matching_edges(
          workflow.edges,
          current_node_id=current_node_id,
          edge_context=edge_context,
      )
      # Dangling edges (target not declared in the workflow) are ignored, not raised.
      return [
          InitialNodeContract(node_id=target.id, node_type=target.type, status="queued")
          for edge in matching_edges
          if (target := node_map.get(edge.target)) is not None
      ]
  def derive_node_config(
      workflow_version: WorkflowVersionContract,
      node_id: str,
  ) -> dict[str, JSONValue]:
      """Return the configuration declared on a single workflow node.

      Args:
          workflow_version: Version whose ``dsl_json`` holds the workflow DSL.
          node_id: Id of the node whose config is wanted.

      Returns:
          A new ``dict`` built from the node's ``config``, so callers may mutate
          it without touching the parsed definition at the top level. Returns
          ``{}`` when ``parse_workflow_definition`` returns ``None`` or no node
          with ``node_id`` exists.
      """
      parsed = parse_workflow_definition(workflow_version.dsl_json)
      target = get_node_definition(parsed, node_id) if parsed is not None else None
      return {} if target is None else dict(target.config)
  def _get_matching_edges(
      edges: list[EdgeDefinition],
      *,
      current_node_id: str,
      edge_context: dict[str, JSONValue],
  ) -> list[EdgeDefinition]:
      """Select the outgoing edges of ``current_node_id`` whose condition holds.

      Input order is preserved. Because of the short-circuiting ``and``, an
      edge's condition is evaluated only when its ``source`` equals
      ``current_node_id``.
      """
      return [
          edge
          for edge in edges
          if edge.source == current_node_id
          and _matches_edge_condition(edge.condition, edge_context)
      ]
  def _matches_edge_condition(
      condition: str | None,
      context: dict[str, JSONValue],
  ) -> bool:
      """Decide whether an edge condition is satisfied by ``context``.

      Rules are applied in this order to the stripped condition:

      1. ``None`` or whitespace-only: always matches.
      2. Equal to ``context["route"]`` (when that is a ``str``): matches.
      3. ``"true"``/``"false"`` (case-insensitive) while
         ``context["condition_result"]`` is a ``bool``: matches iff the literal
         equals that bool.
      4. Otherwise: the result of ``evaluate_condition_expression``.
      """
      expression = (condition or "").strip()
      if not expression:
          return True

      active_route = context.get("route")
      if isinstance(active_route, str) and expression == active_route:
          return True

      verdict = context.get("condition_result")
      literal = expression.lower()
      if isinstance(verdict, bool) and literal in ("true", "false"):
          return verdict == (literal == "true")

      return evaluate_condition_expression(expression, context)
  def _read_string_value(payload: dict[str, JSONValue], key: str) -> str | None:
      """Return ``payload[key]`` when it is a ``str``, otherwise ``None``.

      A missing key and a present key holding a non-string value both give ``None``.
      """
      candidate = payload.get(key)
      return candidate if isinstance(candidate, str) else None
  def _read_bool_value(payload: dict[str, JSONValue], key: str) -> bool | None:
      """Return ``payload[key]`` when it is a ``bool``, otherwise ``None``.

      A missing key and a non-bool value both give ``None``. Integers such as
      ``0``/``1`` are not ``bool`` instances, so they also yield ``None``.
      """
      candidate = payload.get(key)
      return candidate if isinstance(candidate, bool) else None