planner.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from core_domain import InitialNodeContract, WorkflowVersionContract
  2. from core_dsl import (
  3. EdgeDefinition,
  4. get_initial_node_definition,
  5. get_node_definition,
  6. parse_workflow_definition,
  7. )
  8. from core_shared import JSONValue
  9. from .context import build_template_context, evaluate_condition_expression
  10. def derive_initial_node(workflow_version: WorkflowVersionContract) -> InitialNodeContract | None:
  11. workflow = parse_workflow_definition(workflow_version.dsl_json)
  12. if workflow is None:
  13. return None
  14. node = get_initial_node_definition(workflow)
  15. if node is None:
  16. return None
  17. return InitialNodeContract(node_id=node.id, node_type=node.type, status="queued")
  18. def derive_successor_nodes(
  19. workflow_version: WorkflowVersionContract,
  20. current_node_id: str,
  21. current_output_json: dict[str, JSONValue] | None = None,
  22. run_state_json: dict[str, JSONValue] | None = None,
  23. node_output_json_by_node_id: dict[str, dict[str, JSONValue]] | None = None,
  24. node_output_text_by_node_id: dict[str, str] | None = None) -> list[InitialNodeContract]:
  25. workflow = parse_workflow_definition(workflow_version.dsl_json)
  26. if workflow is None:
  27. return []
  28. node_map = {node.id: node for node in workflow.nodes}
  29. template_context = build_template_context(
  30. node_id=current_node_id,
  31. node_type=node_map.get(current_node_id).type if current_node_id in node_map else "unknown",
  32. run_state_json=run_state_json or {},
  33. node_output_json_by_node_id=node_output_json_by_node_id or {},
  34. node_output_text_by_node_id=node_output_text_by_node_id or {})
  35. edge_context: dict[str, JSONValue] = {
  36. **template_context,
  37. "output": current_output_json or {},
  38. "route": _read_string_value(current_output_json or {}, "route"),
  39. "condition_result": _read_bool_value(current_output_json or {}, "condition_result"),
  40. }
  41. successors: list[InitialNodeContract] = []
  42. for edge in _get_matching_edges(
  43. workflow.edges,
  44. current_node_id=current_node_id,
  45. edge_context=edge_context):
  46. successor = node_map.get(edge.target)
  47. if successor is None:
  48. continue
  49. successors.append(
  50. InitialNodeContract(
  51. node_id=successor.id,
  52. node_type=successor.type,
  53. status="queued")
  54. )
  55. return successors
  56. def derive_node_config(
  57. workflow_version: WorkflowVersionContract,
  58. node_id: str) -> dict[str, JSONValue]:
  59. workflow = parse_workflow_definition(workflow_version.dsl_json)
  60. if workflow is None:
  61. return {}
  62. node = get_node_definition(workflow, node_id)
  63. if node is None:
  64. return {}
  65. return dict(node.config)
  66. def _get_matching_edges(
  67. edges: list[EdgeDefinition],
  68. *,
  69. current_node_id: str,
  70. edge_context: dict[str, JSONValue]) -> list[EdgeDefinition]:
  71. matching_edges: list[EdgeDefinition] = []
  72. for edge in edges:
  73. if edge.source != current_node_id:
  74. continue
  75. if _matches_edge_condition(edge.condition, edge_context):
  76. matching_edges.append(edge)
  77. return matching_edges
  78. def _matches_edge_condition(
  79. condition: str | None,
  80. context: dict[str, JSONValue]) -> bool:
  81. if condition is None or not condition.strip():
  82. return True
  83. stripped = condition.strip()
  84. route = context.get("route")
  85. if isinstance(route, str) and stripped == route:
  86. return True
  87. condition_result = context.get("condition_result")
  88. if isinstance(condition_result, bool) and stripped.lower() in {"true", "false"}:
  89. return condition_result is (stripped.lower() == "true")
  90. return evaluate_condition_expression(stripped, context)
  91. def _read_string_value(payload: dict[str, JSONValue], key: str) -> str | None:
  92. value = payload.get(key)
  93. if isinstance(value, str):
  94. return value
  95. return None
  96. def _read_bool_value(payload: dict[str, JSONValue], key: str) -> bool | None:
  97. value = payload.get(key)
  98. if isinstance(value, bool):
  99. return value
  100. return None