planner.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from core_domain import InitialNodeContract, WorkflowVersionContract
  2. from core_shared import JSONValue
  3. def derive_initial_node(workflow_version: WorkflowVersionContract) -> InitialNodeContract | None:
  4. dsl = workflow_version.dsl_json
  5. if not isinstance(dsl, dict):
  6. return None
  7. nodes_value = dsl.get("nodes")
  8. if not isinstance(nodes_value, list):
  9. return None
  10. nodes: list[dict[str, JSONValue]] = [
  11. item for item in nodes_value if isinstance(item, dict)
  12. ]
  13. if not nodes:
  14. return None
  15. edges_value = dsl.get("edges")
  16. incoming_targets = _collect_incoming_targets(edges_value)
  17. for node in nodes:
  18. node_id = node.get("id")
  19. node_type = node.get("type")
  20. if isinstance(node_id, str) and isinstance(node_type, str) and node_id not in incoming_targets:
  21. return InitialNodeContract(node_id=node_id, node_type=node_type, status="queued")
  22. first = nodes[0]
  23. first_id = first.get("id")
  24. first_type = first.get("type")
  25. if isinstance(first_id, str) and isinstance(first_type, str):
  26. return InitialNodeContract(node_id=first_id, node_type=first_type, status="queued")
  27. return None
  28. def derive_successor_nodes(
  29. workflow_version: WorkflowVersionContract,
  30. current_node_id: str,
  31. ) -> list[InitialNodeContract]:
  32. dsl = workflow_version.dsl_json
  33. if not isinstance(dsl, dict):
  34. return []
  35. nodes_value = dsl.get("nodes")
  36. edges_value = dsl.get("edges")
  37. if not isinstance(nodes_value, list) or not isinstance(edges_value, list):
  38. return []
  39. node_type_map = _build_node_type_map(nodes_value)
  40. successor_ids = _collect_successor_ids(edges_value, current_node_id)
  41. successors: list[InitialNodeContract] = []
  42. for successor_id in successor_ids:
  43. node_type = node_type_map.get(successor_id)
  44. if node_type is None:
  45. continue
  46. successors.append(
  47. InitialNodeContract(
  48. node_id=successor_id,
  49. node_type=node_type,
  50. status="queued",
  51. )
  52. )
  53. return successors
  54. def _collect_incoming_targets(edges_value: JSONValue | None) -> set[str]:
  55. if not isinstance(edges_value, list):
  56. return set()
  57. incoming_targets: set[str] = set()
  58. for item in edges_value:
  59. if not isinstance(item, dict):
  60. continue
  61. target = item.get("target")
  62. if isinstance(target, str):
  63. incoming_targets.add(target)
  64. return incoming_targets
  65. def _build_node_type_map(nodes_value: list[JSONValue]) -> dict[str, str]:
  66. node_type_map: dict[str, str] = {}
  67. for item in nodes_value:
  68. if not isinstance(item, dict):
  69. continue
  70. node_id = item.get("id")
  71. node_type = item.get("type")
  72. if isinstance(node_id, str) and isinstance(node_type, str):
  73. node_type_map[node_id] = node_type
  74. return node_type_map
  75. def _collect_successor_ids(edges_value: list[JSONValue], current_node_id: str) -> list[str]:
  76. successor_ids: list[str] = []
  77. for item in edges_value:
  78. if not isinstance(item, dict):
  79. continue
  80. source = item.get("source")
  81. target = item.get("target")
  82. if isinstance(source, str) and isinstance(target, str) and source == current_node_id:
  83. successor_ids.append(target)
  84. return successor_ids