test_agent_react.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. from __future__ import annotations
  2. import sys
  3. from pathlib import Path
  4. REPO_ROOT = Path(__file__).resolve().parents[1]
  5. for module_name in list(sys.modules):
  6. if module_name == "app" or module_name.startswith("app."):
  7. del sys.modules[module_name]
  8. for path in [
  9. REPO_ROOT / "libs" / "core-domain" / "src",
  10. REPO_ROOT / "libs" / "core-shared" / "src",
  11. REPO_ROOT / "libs" / "core-db" / "src",
  12. REPO_ROOT / "libs" / "core-events" / "src",
  13. REPO_ROOT / "services" / "agent-service",
  14. ]:
  15. sys.path.insert(0, str(path))
  16. from core_db import Base
  17. from core_domain import ChatCompletionResponseContract
  18. from app.application.services import AgentApplicationService
  19. from app.bootstrap.settings import AgentServiceSettings
  20. from app.db.session import build_session_factory
  21. from app.domain.repositories import (
  22. AgentDefinitionRepository,
  23. AgentRunRepository,
  24. AgentToolInvocationRepository,
  25. AgentVersionRepository,
  26. )
  27. from app.schemas.agent import (
  28. AgentCreateRequest,
  29. AgentRunCreateRequest,
  30. AgentRunExecuteRequest,
  31. AgentVersionCreateRequest,
  32. )
  33. class FakeModelClient:
  34. def __init__(self) -> None:
  35. self.calls = 0
  36. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  37. self.calls += 1
  38. if self.calls == 1:
  39. return ChatCompletionResponseContract(
  40. model="fake",
  41. content=(
  42. '{"action":"tool","tool_code":"lookup_order",'
  43. '"input_json":{"order_id":"123"}}'
  44. ),
  45. )
  46. return ChatCompletionResponseContract(
  47. model="fake",
  48. content='{"action":"finish","answer":"Order lookup attempted."}',
  49. )
  50. class FakeFunctionCallingModelClient:
  51. def __init__(self) -> None:
  52. self.calls = 0
  53. self.request_tools: list[object] = []
  54. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  55. self.calls += 1
  56. self.request_tools.append(getattr(payload, "tools_json", []))
  57. if self.calls == 1:
  58. return ChatCompletionResponseContract(
  59. model="fake",
  60. content="",
  61. finish_reason="tool_calls",
  62. tool_calls_json=[
  63. {
  64. "id": "call_1",
  65. "type": "function",
  66. "function": {
  67. "name": "lookup_order",
  68. "arguments": '{"order_id":"123"}',
  69. },
  70. }
  71. ],
  72. )
  73. return ChatCompletionResponseContract(
  74. model="fake",
  75. content='{"action":"finish","answer":"Function call handled."}',
  76. )
  77. def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
  78. session_factory = build_session_factory(
  79. settings=AgentServiceSettings(
  80. database_url=f"sqlite:///{tmp_path / 'agent_service.db'}",
  81. ),
  82. )
  83. engine = session_factory.kw["bind"]
  84. Base.metadata.create_all(bind=engine)
  85. with session_factory() as db:
  86. service = AgentApplicationService(
  87. agent_repository=AgentDefinitionRepository(db),
  88. agent_version_repository=AgentVersionRepository(db),
  89. agent_run_repository=AgentRunRepository(db),
  90. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  91. model_gateway_client=FakeModelClient(),
  92. memory_client=None,
  93. tool_client=None,
  94. skill_client=None,
  95. event_client=None,
  96. react_max_steps=3,
  97. )
  98. agent = service.create_agent(
  99. AgentCreateRequest(tenant_id="t1", code="react", name="React")
  100. )
  101. service.create_agent_version(
  102. AgentVersionCreateRequest(
  103. tenant_id="t1",
  104. agent_id=agent.id,
  105. status="published",
  106. system_prompt="Use ReAct.",
  107. model_config={"react_enabled": True, "react_max_steps": 3},
  108. memory_policy={"enabled": False, "write_enabled": False},
  109. tool_refs=[
  110. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  111. ],
  112. )
  113. )
  114. run = service.create_agent_run(
  115. AgentRunCreateRequest(
  116. tenant_id="t1",
  117. agent_id=agent.id,
  118. input_text="check order",
  119. )
  120. )
  121. result = service.execute_agent_run(
  122. agent_run_id=run.id,
  123. payload=AgentRunExecuteRequest(tenant_id="t1", worker_key="test"),
  124. )
  125. assert result is not None
  126. assert result.status == "completed"
  127. assert result.output_text == "Order lookup attempted."
  128. assert result.output_json is not None
  129. assert result.output_json["react_enabled"] is True
  130. assert len(result.output_json["react_steps"]) == 2
  131. invocations = service.list_agent_tool_invocations(
  132. tenant_id="t1",
  133. agent_run_id=run.id,
  134. )
  135. assert len(invocations) == 2
  136. assert all(item.status == "skipped" for item in invocations)
  137. engine.dispose()
  138. def test_react_loop_accepts_openai_tool_calls(tmp_path: Path) -> None:
  139. session_factory = build_session_factory(
  140. settings=AgentServiceSettings(
  141. database_url=f"sqlite:///{tmp_path / 'agent_service.db'}",
  142. ),
  143. )
  144. engine = session_factory.kw["bind"]
  145. Base.metadata.create_all(bind=engine)
  146. with session_factory() as db:
  147. model_client = FakeFunctionCallingModelClient()
  148. service = AgentApplicationService(
  149. agent_repository=AgentDefinitionRepository(db),
  150. agent_version_repository=AgentVersionRepository(db),
  151. agent_run_repository=AgentRunRepository(db),
  152. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  153. model_gateway_client=model_client,
  154. memory_client=None,
  155. tool_client=None,
  156. skill_client=None,
  157. event_client=None,
  158. react_max_steps=3,
  159. )
  160. agent = service.create_agent(
  161. AgentCreateRequest(tenant_id="t1", code="react-fn", name="React Function")
  162. )
  163. service.create_agent_version(
  164. AgentVersionCreateRequest(
  165. tenant_id="t1",
  166. agent_id=agent.id,
  167. status="published",
  168. system_prompt="Use ReAct.",
  169. model_config={
  170. "react_enabled": True,
  171. "react_max_steps": 3,
  172. "function_calling_enabled": True,
  173. },
  174. memory_policy={"enabled": False, "write_enabled": False},
  175. tool_refs=[
  176. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  177. ],
  178. )
  179. )
  180. run = service.create_agent_run(
  181. AgentRunCreateRequest(
  182. tenant_id="t1",
  183. agent_id=agent.id,
  184. input_text="check order",
  185. )
  186. )
  187. result = service.execute_agent_run(
  188. agent_run_id=run.id,
  189. payload=AgentRunExecuteRequest(tenant_id="t1", worker_key="test"),
  190. )
  191. assert result is not None
  192. assert result.status == "completed"
  193. assert result.output_text == "Function call handled."
  194. assert result.output_json is not None
  195. assert result.output_json["react_steps"][0]["action"]["tool_call_protocol"] == "openai"
  196. assert result.output_json["react_steps"][0]["action"]["input_json"] == {"order_id": "123"}
  197. assert model_client.request_tools[0]
  198. engine.dispose()