smoke_runtime_no_key.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. from __future__ import annotations
  2. import json
  3. import os
  4. import sys
  5. import uuid
  6. from dataclasses import dataclass
  7. import httpx
  8. WORKFLOW_SERVICE_URL = os.getenv(
  9. "AGENT_PLATFORM_SMOKE_WORKFLOW_URL",
  10. "http://127.0.0.1:8002/workflows",
  11. )
  12. RUNTIME_SERVICE_URL = os.getenv(
  13. "AGENT_PLATFORM_SMOKE_RUNTIME_URL",
  14. "http://127.0.0.1:8003/runtime",
  15. )
  16. TENANT_ID = os.getenv("AGENT_PLATFORM_SMOKE_TENANT_ID", "t-smoke")
  17. SMOKE_API_KEY = os.getenv("AGENT_PLATFORM_SMOKE_API_KEY")
  18. @dataclass(frozen=True)
  19. class SmokeScenario:
  20. score: int
  21. expected_branch_node_id: str
  22. expected_output_text: str
  23. SCENARIOS = (
  24. SmokeScenario(
  25. score=7,
  26. expected_branch_node_id="high_path",
  27. expected_output_text="Alice passed with score 7",
  28. ),
  29. SmokeScenario(
  30. score=3,
  31. expected_branch_node_id="low_path",
  32. expected_output_text="Alice did not pass; score 3",
  33. ),
  34. )
  35. def main() -> int:
  36. unique_suffix = uuid.uuid4().hex[:8]
  37. headers = {"x-tenant-id": TENANT_ID}
  38. if SMOKE_API_KEY:
  39. headers["x-api-key"] = SMOKE_API_KEY
  40. with httpx.Client(timeout=20.0, headers=headers) as client:
  41. app_id = create_app(client, unique_suffix)
  42. workflow_id = create_workflow(client, app_id, unique_suffix)
  43. results: list[dict[str, object]] = []
  44. for scenario in SCENARIOS:
  45. results.append(run_scenario(client, app_id, workflow_id, unique_suffix, scenario))
  46. results.append(run_retriever_scenario(client, app_id, workflow_id, unique_suffix))
  47. print(json.dumps(results, ensure_ascii=False, indent=2))
  48. return 0
  49. def create_app(client: httpx.Client, unique_suffix: str) -> str:
  50. response = client.post(
  51. f"{WORKFLOW_SERVICE_URL}/apps",
  52. json={
  53. "tenant_id": TENANT_ID,
  54. "code": f"smoke-app-{unique_suffix}",
  55. "name": f"Smoke App {unique_suffix}",
  56. },
  57. )
  58. response.raise_for_status()
  59. payload = response.json()
  60. return str(payload["id"])
  61. def create_workflow(client: httpx.Client, app_id: str, unique_suffix: str) -> str:
  62. response = client.post(
  63. WORKFLOW_SERVICE_URL,
  64. json={
  65. "tenant_id": TENANT_ID,
  66. "app_id": app_id,
  67. "code": f"smoke-flow-{unique_suffix}",
  68. "name": f"Smoke Flow {unique_suffix}",
  69. },
  70. )
  71. response.raise_for_status()
  72. payload = response.json()
  73. return str(payload["id"])
  74. def run_scenario(
  75. client: httpx.Client,
  76. app_id: str,
  77. workflow_id: str,
  78. unique_suffix: str,
  79. scenario: SmokeScenario,
  80. ) -> dict[str, object]:
  81. workflow_version_id = create_workflow_version(client, workflow_id, unique_suffix, scenario.score)
  82. app_version_id = create_app_version(client, app_id, workflow_version_id)
  83. run_id = create_run(client, app_id, app_version_id, workflow_id, workflow_version_id)
  84. execute_run(client, run_id)
  85. node_runs = list_node_runs(client, run_id)
  86. artifacts = list_node_artifacts(client, run_id)
  87. if len(artifacts) < 3:
  88. raise AssertionError(f"expected at least 3 artifacts, got {len(artifacts)}")
  89. trace_spans = list_trace_spans(client, run_id)
  90. if len(trace_spans) < 3:
  91. raise AssertionError(f"expected at least 3 trace spans, got {len(trace_spans)}")
  92. node_map = {str(item["node_id"]): item for item in node_runs}
  93. assert scenario.expected_branch_node_id in node_map, (
  94. f"expected branch node not found: {scenario.expected_branch_node_id}"
  95. )
  96. expected_node = node_map[scenario.expected_branch_node_id]
  97. actual_output_text = expected_node.get("output_text")
  98. if actual_output_text != scenario.expected_output_text:
  99. raise AssertionError(
  100. f"unexpected output_text for {scenario.expected_branch_node_id}: {actual_output_text!r}"
  101. )
  102. other_branch_node_id = "low_path" if scenario.expected_branch_node_id == "high_path" else "high_path"
  103. if other_branch_node_id in node_map:
  104. raise AssertionError(f"unexpected branch node executed: {other_branch_node_id}")
  105. return {
  106. "score": scenario.score,
  107. "executed_node_ids": [str(item["node_id"]) for item in node_runs],
  108. "branch_output_text": actual_output_text,
  109. "artifact_count": len(artifacts),
  110. "trace_span_count": len(trace_spans),
  111. }
  112. def run_retriever_scenario(
  113. client: httpx.Client,
  114. app_id: str,
  115. workflow_id: str,
  116. unique_suffix: str,
  117. ) -> dict[str, object]:
  118. workflow_version_id = create_retriever_workflow_version(client, workflow_id, unique_suffix)
  119. app_version_id = create_app_version(client, app_id, workflow_version_id)
  120. run_id = create_run(client, app_id, app_version_id, workflow_id, workflow_version_id)
  121. execute_run(client, run_id)
  122. node_runs = list_node_runs(client, run_id)
  123. artifacts = list_node_artifacts(client, run_id)
  124. if len(artifacts) < 3:
  125. raise AssertionError(f"expected at least 3 retriever artifacts, got {len(artifacts)}")
  126. trace_spans = list_trace_spans(client, run_id)
  127. if len(trace_spans) < 3:
  128. raise AssertionError(f"expected at least 3 retriever trace spans, got {len(trace_spans)}")
  129. node_map = {str(item["node_id"]): item for item in node_runs}
  130. answer_node = node_map.get("render_answer")
  131. if answer_node is None:
  132. raise AssertionError("retriever answer node was not executed")
  133. answer_text = answer_node.get("output_text")
  134. expected_answer_text = "Top doc: Refund Policy"
  135. if answer_text != expected_answer_text:
  136. raise AssertionError(f"unexpected retriever answer text: {answer_text!r}")
  137. retrieve_node = node_map.get("retrieve_docs")
  138. if retrieve_node is None:
  139. raise AssertionError("retriever node was not executed")
  140. retrieve_output = retrieve_node.get("output_json")
  141. if not isinstance(retrieve_output, dict):
  142. raise AssertionError("retriever output_json must be an object")
  143. return {
  144. "scenario": "retriever",
  145. "executed_node_ids": [str(item["node_id"]) for item in node_runs],
  146. "answer_text": answer_text,
  147. "artifact_count": len(artifacts),
  148. "trace_span_count": len(trace_spans),
  149. }
  150. def create_workflow_version(
  151. client: httpx.Client,
  152. workflow_id: str,
  153. unique_suffix: str,
  154. score: int,
  155. ) -> str:
  156. response = client.post(
  157. f"{WORKFLOW_SERVICE_URL}/versions",
  158. json={
  159. "tenant_id": TENANT_ID,
  160. "workflow_id": workflow_id,
  161. "status": "active",
  162. "dsl_json": build_workflow_dsl(unique_suffix, score),
  163. },
  164. )
  165. response.raise_for_status()
  166. payload = response.json()
  167. return str(payload["id"])
  168. def create_retriever_workflow_version(
  169. client: httpx.Client,
  170. workflow_id: str,
  171. unique_suffix: str,
  172. ) -> str:
  173. response = client.post(
  174. f"{WORKFLOW_SERVICE_URL}/versions",
  175. json={
  176. "tenant_id": TENANT_ID,
  177. "workflow_id": workflow_id,
  178. "status": "active",
  179. "dsl_json": build_retriever_workflow_dsl(unique_suffix),
  180. },
  181. )
  182. response.raise_for_status()
  183. payload = response.json()
  184. return str(payload["id"])
  185. def create_app_version(client: httpx.Client, app_id: str, workflow_version_id: str) -> str:
  186. response = client.post(
  187. f"{WORKFLOW_SERVICE_URL}/apps/versions",
  188. json={
  189. "tenant_id": TENANT_ID,
  190. "app_id": app_id,
  191. "workflow_version_id": workflow_version_id,
  192. "status": "active",
  193. },
  194. )
  195. response.raise_for_status()
  196. payload = response.json()
  197. return str(payload["id"])
  198. def create_run(
  199. client: httpx.Client,
  200. app_id: str,
  201. app_version_id: str,
  202. workflow_id: str,
  203. workflow_version_id: str,
  204. ) -> str:
  205. response = client.post(
  206. f"{RUNTIME_SERVICE_URL}/runs",
  207. json={
  208. "tenant_id": TENANT_ID,
  209. "app_id": app_id,
  210. "app_version_id": app_version_id,
  211. "workflow_id": workflow_id,
  212. "workflow_version_id": workflow_version_id,
  213. },
  214. )
  215. response.raise_for_status()
  216. payload = response.json()
  217. return str(payload["run"]["id"])
  218. def execute_run(client: httpx.Client, run_id: str) -> None:
  219. response = client.post(
  220. f"{RUNTIME_SERVICE_URL}/runs/{run_id}/execute",
  221. params={"tenant_id": TENANT_ID},
  222. json={"max_steps": 8},
  223. )
  224. response.raise_for_status()
  225. def list_node_runs(client: httpx.Client, run_id: str) -> list[dict[str, object]]:
  226. response = client.get(
  227. f"{RUNTIME_SERVICE_URL}/node-runs",
  228. params={"tenant_id": TENANT_ID, "run_id": run_id},
  229. )
  230. response.raise_for_status()
  231. payload = response.json()
  232. if not isinstance(payload, list):
  233. raise AssertionError("node-runs response must be a list")
  234. return [item for item in payload if isinstance(item, dict)]
  235. def list_node_artifacts(client: httpx.Client, run_id: str) -> list[dict[str, object]]:
  236. response = client.get(
  237. f"{RUNTIME_SERVICE_URL}/node-artifacts",
  238. params={"tenant_id": TENANT_ID, "run_id": run_id},
  239. )
  240. response.raise_for_status()
  241. payload = response.json()
  242. if not isinstance(payload, list):
  243. raise AssertionError("node-artifacts response must be a list")
  244. return [item for item in payload if isinstance(item, dict)]
  245. def list_trace_spans(client: httpx.Client, run_id: str) -> list[dict[str, object]]:
  246. response = client.get(
  247. f"{RUNTIME_SERVICE_URL}/trace-spans",
  248. params={"tenant_id": TENANT_ID, "run_id": run_id},
  249. )
  250. response.raise_for_status()
  251. payload = response.json()
  252. if not isinstance(payload, list):
  253. raise AssertionError("trace-spans response must be a list")
  254. return [item for item in payload if isinstance(item, dict)]
  255. def build_workflow_dsl(unique_suffix: str, score: int) -> dict[str, object]:
  256. return {
  257. "code": f"smoke-flow-{unique_suffix}-{score}",
  258. "name": f"Smoke Flow {score}",
  259. "nodes": [
  260. {
  261. "id": "seed_state",
  262. "type": "assigner",
  263. "config": {
  264. "assignments": {
  265. "score": score,
  266. "user_name": "Alice",
  267. },
  268. },
  269. },
  270. {
  271. "id": "check_score",
  272. "type": "if-else",
  273. "config": {
  274. "expression": "state.score >= 5",
  275. },
  276. },
  277. {
  278. "id": "high_path",
  279. "type": "template-transform",
  280. "config": {
  281. "template": "{{state.user_name}} passed with score {{state.score}}",
  282. },
  283. },
  284. {
  285. "id": "low_path",
  286. "type": "template-transform",
  287. "config": {
  288. "template": "{{state.user_name}} did not pass; score {{state.score}}",
  289. },
  290. },
  291. ],
  292. "edges": [
  293. {"source": "seed_state", "target": "check_score"},
  294. {"source": "check_score", "target": "high_path", "condition": "true"},
  295. {"source": "check_score", "target": "low_path", "condition": "false"},
  296. ],
  297. }
  298. def build_retriever_workflow_dsl(unique_suffix: str) -> dict[str, object]:
  299. return {
  300. "code": f"smoke-retriever-{unique_suffix}",
  301. "name": "Smoke Retriever Flow",
  302. "nodes": [
  303. {
  304. "id": "seed_query",
  305. "type": "assigner",
  306. "config": {
  307. "assignments": {
  308. "query": "refund policy",
  309. },
  310. },
  311. },
  312. {
  313. "id": "retrieve_docs",
  314. "type": "knowledge-retrieval",
  315. "config": {
  316. "query_template": "{{state.query}}",
  317. "top_k": 1,
  318. "documents": [
  319. {
  320. "id": "shipping",
  321. "title": "Shipping Policy",
  322. "text": "Shipping usually takes three to five business days.",
  323. },
  324. {
  325. "id": "refund",
  326. "title": "Refund Policy",
  327. "text": "Refund policy allows returns within seven days after delivery.",
  328. },
  329. ],
  330. },
  331. },
  332. {
  333. "id": "render_answer",
  334. "type": "template-transform",
  335. "config": {
  336. "template": "Top doc: {{nodes.retrieve_docs.output.retrieved_documents.0.title}}",
  337. },
  338. },
  339. ],
  340. "edges": [
  341. {"source": "seed_query", "target": "retrieve_docs"},
  342. {"source": "retrieve_docs", "target": "render_answer"},
  343. ],
  344. }
  345. if __name__ == "__main__":
  346. try:
  347. raise SystemExit(main())
  348. except Exception as exc:
  349. print(f"smoke test failed: {exc}", file=sys.stderr)
  350. raise