context.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import json
  2. import re
  3. from collections.abc import Callable
  4. from core_shared import JSONValue
  5. TEMPLATE_PATTERN = re.compile(r"\{\{\s*(?P<expr>[^{}]+?)\s*\}\}")
  6. COMPARISON_OPERATORS = ("==", "!=", ">=", "<=", ">", "<")
  7. def build_template_context(
  8. *,
  9. node_id: str,
  10. node_type: str,
  11. run_state_json: dict[str, JSONValue],
  12. node_output_json_by_node_id: dict[str, dict[str, JSONValue]],
  13. node_output_text_by_node_id: dict[str, str],
  14. ) -> dict[str, JSONValue]:
  15. current_node_outputs = node_output_json_by_node_id.get(node_id, {})
  16. current_node_text = node_output_text_by_node_id.get(node_id)
  17. return {
  18. "state": run_state_json,
  19. "nodes": {
  20. item_node_id: {
  21. "output": output_json,
  22. "text": node_output_text_by_node_id.get(item_node_id),
  23. }
  24. for item_node_id, output_json in node_output_json_by_node_id.items()
  25. },
  26. "current": {
  27. "node_id": node_id,
  28. "node_type": node_type,
  29. "output": current_node_outputs,
  30. "text": current_node_text,
  31. },
  32. }
  33. def render_template_string(template: str, context: dict[str, JSONValue]) -> str:
  34. def replace(match: re.Match[str]) -> str:
  35. expression = match.group("expr").strip()
  36. value = resolve_expression(context, expression)
  37. if value is None:
  38. return ""
  39. if isinstance(value, (dict, list)):
  40. return json.dumps(value, ensure_ascii=True, separators=(",", ":"))
  41. return str(value)
  42. return TEMPLATE_PATTERN.sub(replace, template)
  43. def render_json_value(value: JSONValue, context: dict[str, JSONValue]) -> JSONValue:
  44. if isinstance(value, str):
  45. return render_template_string(value, context)
  46. if isinstance(value, list):
  47. return [render_json_value(item, context) for item in value]
  48. if isinstance(value, dict):
  49. return {
  50. str(item_key): render_json_value(item_value, context)
  51. for item_key, item_value in value.items()
  52. }
  53. return value
  54. def evaluate_condition_expression(expression: str, context: dict[str, JSONValue]) -> bool:
  55. stripped_expression = expression.strip()
  56. if not stripped_expression:
  57. return False
  58. for operator in COMPARISON_OPERATORS:
  59. if operator in stripped_expression:
  60. left_text, right_text = stripped_expression.split(operator, 1)
  61. left_value = resolve_expression(context, left_text.strip())
  62. right_value = resolve_expression(context, right_text.strip())
  63. return compare_values(left_value, right_value, operator)
  64. resolved = resolve_expression(context, stripped_expression)
  65. return coerce_bool(resolved)
  66. def resolve_expression(context: dict[str, JSONValue], expression: str) -> JSONValue:
  67. if expression == "":
  68. return None
  69. if (expression.startswith('"') and expression.endswith('"')) or (
  70. expression.startswith("'") and expression.endswith("'")
  71. ):
  72. return expression[1:-1]
  73. lowered = expression.lower()
  74. if lowered == "true":
  75. return True
  76. if lowered == "false":
  77. return False
  78. if lowered == "null":
  79. return None
  80. integer_value = try_parse_int(expression)
  81. if integer_value is not None:
  82. return integer_value
  83. float_value = try_parse_float(expression)
  84. if float_value is not None:
  85. return float_value
  86. return resolve_reference(context, expression)
  87. def resolve_reference(context: dict[str, JSONValue], path: str) -> JSONValue:
  88. current: JSONValue = context
  89. for segment in path.split("."):
  90. if not segment:
  91. return None
  92. if isinstance(current, dict):
  93. current = current.get(segment)
  94. continue
  95. if isinstance(current, list) and segment.isdigit():
  96. index = int(segment)
  97. if index < 0 or index >= len(current):
  98. return None
  99. current = current[index]
  100. continue
  101. return None
  102. return current
  103. def coerce_bool(value: JSONValue) -> bool:
  104. if isinstance(value, bool):
  105. return value
  106. if value is None:
  107. return False
  108. if isinstance(value, (int, float)):
  109. return value != 0
  110. if isinstance(value, str):
  111. lowered = value.strip().lower()
  112. if lowered in {"", "false", "0", "null", "none"}:
  113. return False
  114. return True
  115. if isinstance(value, (list, dict)):
  116. return len(value) > 0
  117. return False
  118. def compare_values(left: JSONValue, right: JSONValue, operator: str) -> bool:
  119. if operator == "==":
  120. return left == right
  121. if operator == "!=":
  122. return left != right
  123. if operator == ">":
  124. return compare_order(left, right, lambda x, y: x > y)
  125. if operator == "<":
  126. return compare_order(left, right, lambda x, y: x < y)
  127. if operator == ">=":
  128. return compare_order(left, right, lambda x, y: x >= y)
  129. if operator == "<=":
  130. return compare_order(left, right, lambda x, y: x <= y)
  131. return False
  132. def compare_order(
  133. left: JSONValue,
  134. right: JSONValue,
  135. operator: Callable[[int | float | str, int | float | str], bool],
  136. ) -> bool:
  137. if isinstance(left, (int, float)) and isinstance(right, (int, float)):
  138. return bool(operator(left, right))
  139. if isinstance(left, str) and isinstance(right, str):
  140. return bool(operator(left, right))
  141. return False
  142. def try_parse_int(value: str) -> int | None:
  143. if not value or any(item in value for item in {".", "e", "E"}):
  144. return None
  145. if value.startswith(("+", "-")):
  146. digits = value[1:]
  147. else:
  148. digits = value
  149. if not digits.isdigit():
  150. return None
  151. return int(value)
  152. def try_parse_float(value: str) -> float | None:
  153. try:
  154. parsed = float(value)
  155. except ValueError:
  156. return None
  157. if parsed.is_integer() and "." not in value and "e" not in value.lower():
  158. return None
  159. return parsed