test_agent_react.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from __future__ import annotations
  2. import sys
  3. from pathlib import Path
  4. REPO_ROOT = Path(__file__).resolve().parents[1]
  5. for module_name in list(sys.modules):
  6. if module_name == "app" or module_name.startswith("app."):
  7. del sys.modules[module_name]
  8. for path in [
  9. REPO_ROOT / "libs" / "core-domain" / "src",
  10. REPO_ROOT / "libs" / "core-shared" / "src",
  11. REPO_ROOT / "libs" / "core-db" / "src",
  12. REPO_ROOT / "libs" / "core-events" / "src",
  13. REPO_ROOT / "services" / "agent-service",
  14. ]:
  15. sys.path.insert(0, str(path))
  16. from core_db import Base
  17. from core_domain import ChatCompletionResponseContract
  18. from app.application.services import AgentApplicationService
  19. from app.bootstrap.settings import AgentServiceSettings
  20. from app.db.session import build_session_factory
  21. from app.domain.repositories import (
  22. AgentDefinitionRepository,
  23. AgentRunRepository,
  24. AgentToolInvocationRepository,
  25. AgentVersionRepository,
  26. )
  27. from app.schemas.agent import (
  28. AgentCreateRequest,
  29. AgentRunCreateRequest,
  30. AgentRunExecuteRequest,
  31. AgentVersionCreateRequest,
  32. )
  33. class FakeModelClient:
  34. def __init__(self) -> None:
  35. self.calls = 0
  36. def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
  37. self.calls += 1
  38. if self.calls == 1:
  39. return ChatCompletionResponseContract(
  40. model="fake",
  41. content=(
  42. '{"action":"tool","tool_code":"lookup_order",'
  43. '"input_json":{"order_id":"123"}}'
  44. ),
  45. )
  46. return ChatCompletionResponseContract(
  47. model="fake",
  48. content='{"action":"finish","answer":"Order lookup attempted."}',
  49. )
  50. def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
  51. session_factory = build_session_factory(
  52. settings=AgentServiceSettings(
  53. database_url=f"sqlite:///{tmp_path / 'agent_service.db'}",
  54. ),
  55. )
  56. engine = session_factory.kw["bind"]
  57. Base.metadata.create_all(bind=engine)
  58. with session_factory() as db:
  59. service = AgentApplicationService(
  60. agent_repository=AgentDefinitionRepository(db),
  61. agent_version_repository=AgentVersionRepository(db),
  62. agent_run_repository=AgentRunRepository(db),
  63. agent_tool_invocation_repository=AgentToolInvocationRepository(db),
  64. model_gateway_client=FakeModelClient(),
  65. memory_client=None,
  66. tool_client=None,
  67. skill_client=None,
  68. event_client=None,
  69. react_max_steps=3,
  70. )
  71. agent = service.create_agent(
  72. AgentCreateRequest(tenant_id="t1", code="react", name="React")
  73. )
  74. service.create_agent_version(
  75. AgentVersionCreateRequest(
  76. tenant_id="t1",
  77. agent_id=agent.id,
  78. status="published",
  79. system_prompt="Use ReAct.",
  80. model_config={"react_enabled": True, "react_max_steps": 3},
  81. memory_policy={"enabled": False, "write_enabled": False},
  82. tool_refs=[
  83. {"tool_code": "lookup_order", "required": True, "config_json": {}}
  84. ],
  85. )
  86. )
  87. run = service.create_agent_run(
  88. AgentRunCreateRequest(
  89. tenant_id="t1",
  90. agent_id=agent.id,
  91. input_text="check order",
  92. )
  93. )
  94. result = service.execute_agent_run(
  95. agent_run_id=run.id,
  96. payload=AgentRunExecuteRequest(tenant_id="t1", worker_key="test"),
  97. )
  98. assert result is not None
  99. assert result.status == "completed"
  100. assert result.output_text == "Order lookup attempted."
  101. assert result.output_json is not None
  102. assert result.output_json["react_enabled"] is True
  103. assert len(result.output_json["react_steps"]) == 2
  104. invocations = service.list_agent_tool_invocations(
  105. tenant_id="t1",
  106. agent_run_id=run.id,
  107. )
  108. assert len(invocations) == 2
  109. assert all(item.status == "skipped" for item in invocations)
  110. engine.dispose()