context.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import json
  2. import re
  3. from collections.abc import Callable
  4. from core_shared import JSONValue
  5. TEMPLATE_PATTERN = re.compile(r"\{\{\s*(?P<expr>[^{}]+?)\s*\}\}")
  6. COMPARISON_OPERATORS = ("==", "!=", ">=", "<=", ">", "<")
  7. def build_template_context(
  8. *,
  9. node_id: str,
  10. node_type: str,
  11. run_state_json: dict[str, JSONValue],
  12. node_output_json_by_node_id: dict[str, dict[str, JSONValue]],
  13. node_output_text_by_node_id: dict[str, str]) -> dict[str, JSONValue]:
  14. current_node_outputs = node_output_json_by_node_id.get(node_id, {})
  15. current_node_text = node_output_text_by_node_id.get(node_id)
  16. return {
  17. "state": run_state_json,
  18. "nodes": {
  19. item_node_id: {
  20. "output": output_json,
  21. "text": node_output_text_by_node_id.get(item_node_id),
  22. }
  23. for item_node_id, output_json in node_output_json_by_node_id.items()
  24. },
  25. "current": {
  26. "node_id": node_id,
  27. "node_type": node_type,
  28. "output": current_node_outputs,
  29. "text": current_node_text,
  30. },
  31. }
  32. def render_template_string(template: str, context: dict[str, JSONValue]) -> str:
  33. def replace(match: re.Match[str]) -> str:
  34. expression = match.group("expr").strip()
  35. value = resolve_expression(context, expression)
  36. if value is None:
  37. return ""
  38. if isinstance(value, (dict, list)):
  39. return json.dumps(value, ensure_ascii=True, separators=(",", ":"))
  40. return str(value)
  41. return TEMPLATE_PATTERN.sub(replace, template)
  42. def render_json_value(value: JSONValue, context: dict[str, JSONValue]) -> JSONValue:
  43. if isinstance(value, str):
  44. return render_template_string(value, context)
  45. if isinstance(value, list):
  46. return [render_json_value(item, context) for item in value]
  47. if isinstance(value, dict):
  48. return {
  49. str(item_key): render_json_value(item_value, context)
  50. for item_key, item_value in value.items()
  51. }
  52. return value
  53. def evaluate_condition_expression(expression: str, context: dict[str, JSONValue]) -> bool:
  54. stripped_expression = expression.strip()
  55. if not stripped_expression:
  56. return False
  57. for operator in COMPARISON_OPERATORS:
  58. if operator in stripped_expression:
  59. left_text, right_text = stripped_expression.split(operator, 1)
  60. left_value = resolve_expression(context, left_text.strip())
  61. right_value = resolve_expression(context, right_text.strip())
  62. return compare_values(left_value, right_value, operator)
  63. resolved = resolve_expression(context, stripped_expression)
  64. return coerce_bool(resolved)
  65. def resolve_expression(context: dict[str, JSONValue], expression: str) -> JSONValue:
  66. if expression == "":
  67. return None
  68. if (expression.startswith('"') and expression.endswith('"')) or (
  69. expression.startswith("'") and expression.endswith("'")
  70. ):
  71. return expression[1:-1]
  72. lowered = expression.lower()
  73. if lowered == "true":
  74. return True
  75. if lowered == "false":
  76. return False
  77. if lowered == "null":
  78. return None
  79. integer_value = try_parse_int(expression)
  80. if integer_value is not None:
  81. return integer_value
  82. float_value = try_parse_float(expression)
  83. if float_value is not None:
  84. return float_value
  85. return resolve_reference(context, expression)
  86. def resolve_reference(context: dict[str, JSONValue], path: str) -> JSONValue:
  87. current: JSONValue = context
  88. for segment in path.split("."):
  89. if not segment:
  90. return None
  91. if isinstance(current, dict):
  92. current = current.get(segment)
  93. continue
  94. if isinstance(current, list) and segment.isdigit():
  95. index = int(segment)
  96. if index < 0 or index >= len(current):
  97. return None
  98. current = current[index]
  99. continue
  100. return None
  101. return current
  102. def coerce_bool(value: JSONValue) -> bool:
  103. if isinstance(value, bool):
  104. return value
  105. if value is None:
  106. return False
  107. if isinstance(value, (int, float)):
  108. return value != 0
  109. if isinstance(value, str):
  110. lowered = value.strip().lower()
  111. if lowered in {"", "false", "0", "null", "none"}:
  112. return False
  113. return True
  114. if isinstance(value, (list, dict)):
  115. return len(value) > 0
  116. return False
  117. def compare_values(left: JSONValue, right: JSONValue, operator: str) -> bool:
  118. if operator == "==":
  119. return left == right
  120. if operator == "!=":
  121. return left != right
  122. if operator == ">":
  123. return compare_order(left, right, lambda x, y: x > y)
  124. if operator == "<":
  125. return compare_order(left, right, lambda x, y: x < y)
  126. if operator == ">=":
  127. return compare_order(left, right, lambda x, y: x >= y)
  128. if operator == "<=":
  129. return compare_order(left, right, lambda x, y: x <= y)
  130. return False
  131. def compare_order(
  132. left: JSONValue,
  133. right: JSONValue,
  134. operator: Callable[[int | float | str, int | float | str], bool]) -> bool:
  135. if isinstance(left, (int, float)) and isinstance(right, (int, float)):
  136. return bool(operator(left, right))
  137. if isinstance(left, str) and isinstance(right, str):
  138. return bool(operator(left, right))
  139. return False
  140. def try_parse_int(value: str) -> int | None:
  141. if not value or any(item in value for item in {".", "e", "E"}):
  142. return None
  143. if value.startswith(("+", "-")):
  144. digits = value[1:]
  145. else:
  146. digits = value
  147. if not digits.isdigit():
  148. return None
  149. return int(value)
  150. def try_parse_float(value: str) -> float | None:
  151. try:
  152. parsed = float(value)
  153. except ValueError:
  154. return None
  155. if parsed.is_integer() and "." not in value and "e" not in value.lower():
  156. return None
  157. return parsed