# test_agent_react.py
  1. from __future__ import annotations
  2. import sys
  3. from pathlib import Path
  4. REPO_ROOT = Path(__file__).resolve().parents[1]
  5. for module_name in list(sys.modules):
  6. if module_name == "app" or module_name.startswith("app."):
  7. del sys.modules[module_name]
  8. for path in [
  9. REPO_ROOT / "libs" / "core-domain" / "src",
  10. REPO_ROOT / "libs" / "core-shared" / "src",
  11. REPO_ROOT / "libs" / "core-db" / "src",
  12. REPO_ROOT / "libs" / "core-events" / "src",
  13. REPO_ROOT / "services" / "agent-service",
  14. ]:
  15. sys.path.insert(0, str(path))
  16. from app.application.services import AgentApplicationService
  17. from app.bootstrap.settings import AgentServiceSettings
  18. from app.db.session import build_session_factory
  19. from app.domain.repositories import (
  20. AgentDefinitionRepository,
  21. AgentRunRepository,
  22. AgentToolInvocationRepository,
  23. AgentVersionRepository,
  24. )
  25. from app.schemas.agent import (
  26. AgentCreateRequest,
  27. AgentRunCreateRequest,
  28. AgentRunExecuteRequest,
  29. AgentVersionCreateRequest,
  30. )
  31. from core_db import Base
  32. from core_domain import ChatCompletionResponseContract
  33. class FakeModelClient:
  34. def __init__(self) -> None:
  35. self.calls = 0
  36. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  37. self.calls += 1
  38. if self.calls == 1:
  39. return ChatCompletionResponseContract(
  40. model="fake",
  41. content=(
  42. '{"action":"tool","tool_code":"lookup_order",'
  43. '"input_json":{"order_id":"123"}}'
  44. ))
  45. return ChatCompletionResponseContract(
  46. model="fake",
  47. content='{"action":"finish","answer":"Order lookup attempted."}')
  48. class FakeFunctionCallingModelClient:
  49. def __init__(self) -> None:
  50. self.calls = 0
  51. self.request_tools: list[object] = []
  52. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  53. self.calls += 1
  54. self.request_tools.append(getattr(payload, "tools_json", []))
  55. if self.calls == 1:
  56. return ChatCompletionResponseContract(
  57. model="fake",
  58. content="",
  59. finish_reason="tool_calls",
  60. tool_calls_json=[
  61. {
  62. "id": "call_1",
  63. "type": "function",
  64. "function": {
  65. "name": "lookup_order",
  66. "arguments": '{"order_id":"123"}',
  67. },
  68. }
  69. ])
  70. return ChatCompletionResponseContract(
  71. model="fake",
  72. content='{"action":"finish","answer":"Function call handled."}')
  73. def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
  74. session_factory = build_session_factory(
  75. settings=AgentServiceSettings(
  76. database_url=f"sqlite:///{tmp_path / 'agent_service.db'}"))
  77. engine = session_factory.kw["bind"]
  78. Base.metadata.create_all(bind=engine)
  79. with session_factory() as db:
  80. service = AgentApplicationService(
  81. agent_repository=AgentDefinitionRepository(db),
  82. agent_version_repository=AgentVersionRepository(db),
  83. agent_run_repository=AgentRunRepository(db),
  84. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  85. model_gateway_client=FakeModelClient(),
  86. memory_client=None,
  87. tool_client=None,
  88. skill_client=None,
  89. event_client=None,
  90. react_max_steps=3)
  91. agent = service.create_agent(
  92. AgentCreateRequest(code="react", name="React")
  93. )
  94. service.create_agent_version(
  95. AgentVersionCreateRequest(
  96. agent_id=agent.id,
  97. status="published",
  98. system_prompt="Use ReAct.",
  99. model_config={"react_enabled": True, "react_max_steps": 3},
  100. memory_policy={"enabled": False, "write_enabled": False},
  101. tool_refs=[
  102. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  103. ])
  104. )
  105. run = service.create_agent_run(
  106. AgentRunCreateRequest(
  107. agent_id=agent.id,
  108. input_text="check order")
  109. )
  110. result = service.execute_agent_run(
  111. agent_run_id=run.id,
  112. payload=AgentRunExecuteRequest(worker_key="test"))
  113. assert result is not None
  114. assert result.status == "completed"
  115. assert result.output_text == "Order lookup attempted."
  116. assert result.output_json is not None
  117. assert result.output_json["react_enabled"] is True
  118. assert len(result.output_json["react_steps"]) == 2
  119. invocations = service.list_agent_tool_invocations(
  120. agent_run_id=run.id)
  121. assert len(invocations) == 2
  122. assert all(item.status == "skipped" for item in invocations)
  123. engine.dispose()
  124. def test_react_loop_accepts_openai_tool_calls(tmp_path: Path) -> None:
  125. session_factory = build_session_factory(
  126. settings=AgentServiceSettings(
  127. database_url=f"sqlite:///{tmp_path / 'agent_service.db'}"))
  128. engine = session_factory.kw["bind"]
  129. Base.metadata.create_all(bind=engine)
  130. with session_factory() as db:
  131. model_client = FakeFunctionCallingModelClient()
  132. service = AgentApplicationService(
  133. agent_repository=AgentDefinitionRepository(db),
  134. agent_version_repository=AgentVersionRepository(db),
  135. agent_run_repository=AgentRunRepository(db),
  136. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  137. model_gateway_client=model_client,
  138. memory_client=None,
  139. tool_client=None,
  140. skill_client=None,
  141. event_client=None,
  142. react_max_steps=3)
  143. agent = service.create_agent(
  144. AgentCreateRequest(code="react-fn", name="React Function")
  145. )
  146. service.create_agent_version(
  147. AgentVersionCreateRequest(
  148. agent_id=agent.id,
  149. status="published",
  150. system_prompt="Use ReAct.",
  151. model_config={
  152. "react_enabled": True,
  153. "react_max_steps": 3,
  154. "function_calling_enabled": True,
  155. },
  156. memory_policy={"enabled": False, "write_enabled": False},
  157. tool_refs=[
  158. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  159. ])
  160. )
  161. run = service.create_agent_run(
  162. AgentRunCreateRequest(
  163. agent_id=agent.id,
  164. input_text="check order")
  165. )
  166. result = service.execute_agent_run(
  167. agent_run_id=run.id,
  168. payload=AgentRunExecuteRequest(worker_key="test"))
  169. assert result is not None
  170. assert result.status == "completed"
  171. assert result.output_text == "Function call handled."
  172. assert result.output_json is not None
  173. assert result.output_json["react_steps"][0]["action"]["tool_call_protocol"] == "openai"
  174. assert result.output_json["react_steps"][0]["action"]["input_json"] == {"order_id": "123"}
  175. assert model_client.request_tools[0]
  176. engine.dispose()