test_agent_react.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
# Tests for the agent-service ReAct loop: the text ("action" JSON) protocol
# and the OpenAI-style function-calling (tool_calls) protocol.
from __future__ import annotations
import sys
from pathlib import Path
# Repository root: two levels up from this file's resolved path
# (i.e. the parent of the directory that contains this test module).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root on sys.path so `tests.conftest` is importable;
# this has to run before the `tests.conftest` import on the next line.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from tests.conftest import build_postgres_database_url, prepare_known_service_import
# Called at import time, before the `core_domain` import below.
# NOTE(review): presumably this makes agent-service packages importable
# (sys.path / sys.modules setup) — confirm against tests/conftest.py.
prepare_known_service_import("agent-service")
from core_domain import ChatCompletionResponseContract
  10. class FakeModelClient:
  11. def __init__(self) -> None:
  12. self.calls = 0
  13. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  14. self.calls += 1
  15. if self.calls == 1:
  16. return ChatCompletionResponseContract(
  17. model="fake",
  18. content=(
  19. '{"action":"tool","tool_code":"lookup_order",'
  20. '"input_json":{"order_id":"123"}}'
  21. ))
  22. return ChatCompletionResponseContract(
  23. model="fake",
  24. content='{"action":"finish","answer":"Order lookup attempted."}')
  25. class FakeFunctionCallingModelClient:
  26. def __init__(self) -> None:
  27. self.calls = 0
  28. self.request_tools: list[object] = []
  29. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  30. self.calls += 1
  31. self.request_tools.append(getattr(payload, "tools_json", []))
  32. if self.calls == 1:
  33. return ChatCompletionResponseContract(
  34. model="fake",
  35. content="",
  36. finish_reason="tool_calls",
  37. tool_calls_json=[
  38. {
  39. "id": "call_1",
  40. "type": "function",
  41. "function": {
  42. "name": "lookup_order",
  43. "arguments": '{"order_id":"123"}',
  44. },
  45. }
  46. ])
  47. return ChatCompletionResponseContract(
  48. model="fake",
  49. content='{"action":"finish","answer":"Function call handled."}')
  50. def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
  51. prepare_known_service_import("agent-service")
  52. from app.application.services import AgentApplicationService
  53. from app.bootstrap.settings import AgentServiceSettings
  54. from app.db.session import build_session_factory
  55. from app.domain.repositories import (
  56. AgentDefinitionRepository,
  57. AgentRunRepository,
  58. AgentToolInvocationRepository,
  59. AgentConfigRepository,
  60. )
  61. from app.schemas.agent import (
  62. AgentCreateRequest,
  63. AgentRunCreateRequest,
  64. AgentRunExecuteRequest,
  65. AgentConfigCreateRequest,
  66. )
  67. from core_db import Base
  68. session_factory = build_session_factory(
  69. settings=AgentServiceSettings(
  70. database_url=build_postgres_database_url(tmp_path, "agent-service-react")))
  71. engine = session_factory.kw["bind"]
  72. Base.metadata.create_all(bind=engine)
  73. with session_factory() as db:
  74. service = AgentApplicationService(
  75. agent_repository=AgentDefinitionRepository(db),
  76. agent_config_repository=AgentConfigRepository(db),
  77. agent_run_repository=AgentRunRepository(db),
  78. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  79. model_gateway_client=FakeModelClient(),
  80. memory_client=None,
  81. tool_client=None,
  82. skill_client=None,
  83. event_client=None,
  84. react_max_steps=3)
  85. agent = service.create_agent(
  86. AgentCreateRequest(code="react", name="React")
  87. )
  88. service.create_agent_config(
  89. AgentConfigCreateRequest(
  90. agent_id=agent.id,
  91. system_prompt="Use ReAct.",
  92. model_config={"react_enabled": True, "react_max_steps": 3},
  93. memory_policy={"enabled": False, "write_enabled": False},
  94. tool_refs=[
  95. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  96. ])
  97. )
  98. run = service.create_agent_run(
  99. AgentRunCreateRequest(
  100. agent_id=agent.id,
  101. input_text="check order")
  102. )
  103. result = service.execute_agent_run(
  104. agent_run_id=run.id,
  105. payload=AgentRunExecuteRequest(worker_key="test"))
  106. assert result is not None
  107. assert result.status == "completed"
  108. assert result.output_text == "Order lookup attempted."
  109. assert result.output_json is not None
  110. assert result.output_json["react_enabled"] is True
  111. assert len(result.output_json["react_steps"]) == 2
  112. invocations = service.list_agent_tool_invocations(
  113. agent_run_id=run.id)
  114. assert len(invocations) == 2
  115. assert all(item.status == "skipped" for item in invocations)
  116. engine.dispose()
  117. def test_react_loop_accepts_openai_tool_calls(tmp_path: Path) -> None:
  118. prepare_known_service_import("agent-service")
  119. from app.application.services import AgentApplicationService
  120. from app.bootstrap.settings import AgentServiceSettings
  121. from app.db.session import build_session_factory
  122. from app.domain.repositories import (
  123. AgentDefinitionRepository,
  124. AgentRunRepository,
  125. AgentToolInvocationRepository,
  126. AgentConfigRepository,
  127. )
  128. from app.schemas.agent import (
  129. AgentCreateRequest,
  130. AgentRunCreateRequest,
  131. AgentRunExecuteRequest,
  132. AgentConfigCreateRequest,
  133. )
  134. from core_db import Base
  135. session_factory = build_session_factory(
  136. settings=AgentServiceSettings(
  137. database_url=build_postgres_database_url(tmp_path, "agent-service-react-fn")))
  138. engine = session_factory.kw["bind"]
  139. Base.metadata.create_all(bind=engine)
  140. with session_factory() as db:
  141. model_client = FakeFunctionCallingModelClient()
  142. service = AgentApplicationService(
  143. agent_repository=AgentDefinitionRepository(db),
  144. agent_config_repository=AgentConfigRepository(db),
  145. agent_run_repository=AgentRunRepository(db),
  146. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  147. model_gateway_client=model_client,
  148. memory_client=None,
  149. tool_client=None,
  150. skill_client=None,
  151. event_client=None,
  152. react_max_steps=3)
  153. agent = service.create_agent(
  154. AgentCreateRequest(code="react-fn", name="React Function")
  155. )
  156. service.create_agent_config(
  157. AgentConfigCreateRequest(
  158. agent_id=agent.id,
  159. system_prompt="Use ReAct.",
  160. model_config={
  161. "react_enabled": True,
  162. "react_max_steps": 3,
  163. "function_calling_enabled": True,
  164. },
  165. memory_policy={"enabled": False, "write_enabled": False},
  166. tool_refs=[
  167. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  168. ])
  169. )
  170. run = service.create_agent_run(
  171. AgentRunCreateRequest(
  172. agent_id=agent.id,
  173. input_text="check order")
  174. )
  175. result = service.execute_agent_run(
  176. agent_run_id=run.id,
  177. payload=AgentRunExecuteRequest(worker_key="test"))
  178. assert result is not None
  179. assert result.status == "completed"
  180. assert result.output_text == "Function call handled."
  181. assert result.output_json is not None
  182. assert result.output_json["react_steps"][0]["action"]["tool_call_protocol"] == "openai"
  183. assert result.output_json["react_steps"][0]["action"]["input_json"] == {"order_id": "123"}
  184. assert model_client.request_tools[0]
  185. engine.dispose()