Browse Source

feat: enhance react tools and add ci tests

Jax Docker 1 month ago
parent
commit
0e113963be

+ 18 - 0
.gitlab-ci.yml

@@ -0,0 +1,18 @@
+stages:
+  - test
+
+python-test:  # single CI job: install the packages, byte-compile, run pytest
+  image: python:3.11-slim
+  stage: test
+  before_script:
+    - python -m pip install --upgrade pip
+    - pip install pytest
+    - pip install -e libs/core-shared  # editable installs: four core libs, then two services
+    - pip install -e libs/core-domain
+    - pip install -e libs/core-db
+    - pip install -e libs/core-events
+    - pip install -e services/agent-service
+    - pip install -e services/knowledge-service
+  script:
+    - python -m compileall libs services scripts tests  # byte-compile all four trees
+    - pytest -q

+ 11 - 0
README.md

@@ -1246,6 +1246,17 @@ Run only selected migrations:
 python .\scripts\migrate_all.py --only agent-service --only runtime-service
 ```
 
+Run the automated smoke tests:
+
+```powershell
+pip install pytest
+pytest -q
+```
+
+The repository includes `.gitlab-ci.yml` with a Python 3.11 test job that
+installs the core libraries plus Agent/Knowledge services, runs `compileall`,
+and executes the pytest smoke suite.
+
 Scale runtime workers:
 
 ```powershell

+ 2 - 0
libs/core-domain/src/core_domain/agent_contracts.py

@@ -18,6 +18,8 @@ class AgentModelConfigContract(BaseModel):
     max_tokens: int | None = None
     react_enabled: bool = False
     react_max_steps: int | None = None
+    react_max_tool_calls: int | None = None
+    react_tool_retry_count: int | None = None
     extra_json: dict[str, JSONValue] = Field(default_factory=dict)
 
 

+ 87 - 3
services/agent-service/app/application/services.py

@@ -53,6 +53,8 @@ class AgentApplicationService:
         skill_client: SkillServiceClient | None = None,
         event_client: EventServiceClient | None = None,
         react_max_steps: int = 5,
+        react_max_tool_calls: int = 10,
+        react_tool_retry_count: int = 1,
     ) -> None:
         self.agent_repository = agent_repository
         self.agent_version_repository = agent_version_repository
@@ -64,6 +66,8 @@ class AgentApplicationService:
         self.skill_client = skill_client
         self.event_client = event_client
         self.react_max_steps = react_max_steps
+        self.react_max_tool_calls = react_max_tool_calls
+        self.react_tool_retry_count = react_tool_retry_count
 
     def create_agent(self, payload: AgentCreateRequest) -> AgentDefinition:
         return self.agent_repository.create(
@@ -435,6 +439,7 @@ class AgentApplicationService:
             agent_version=agent_version,
             memory_results=memory_results,
             capability_context=self._format_react_instruction(
+                agent_run=agent_run,
                 selected_tools=selected_tools,
                 skill_invocations=skill_invocations,
             ),
@@ -442,6 +447,7 @@ class AgentApplicationService:
         react_steps: list[dict[str, JSONValue]] = []
         tool_invocations: list[dict[str, JSONValue]] = []
         final_answer: str | None = None
+        tool_call_count = 0
 
         max_steps = self._read_int(
             agent_version.model_config_json,
@@ -488,6 +494,16 @@ class AgentApplicationService:
                 final_answer = response.content
                 break
 
+            max_tool_calls = self._read_int(
+                agent_version.model_config_json,
+                "react_max_tool_calls",
+                default=self.react_max_tool_calls,
+            )
+            if tool_call_count >= max(max_tool_calls, 0):
+                final_answer = "Tool call budget exhausted."
+                react_step["observation"] = final_answer
+                break
+
             tool_code = action.get("tool_code")
             matching_tools = [
                 item for item in selected_tools if item.tool_code == tool_code
@@ -505,11 +521,12 @@ class AgentApplicationService:
                 agent_run.input_json = {
                     str(item_key): item_value for item_key, item_value in tool_input.items()
                 }
-            current_invocations = self._invoke_selected_tools(
+            current_invocations = self._invoke_react_tool_with_retry(
                 agent_run=agent_run,
                 agent_version=agent_version,
-                selected_tools=matching_tools[:1],
+                tool_ref=matching_tools[0],
             )
+            tool_call_count += len(current_invocations)
             agent_run.input_json = original_input_json
             tool_invocations.extend(current_invocations)
             observation = self._format_react_observation(current_invocations)
@@ -535,6 +552,7 @@ class AgentApplicationService:
                 "agent_version_id": agent_version.id,
                 "react_enabled": True,
                 "react_steps": react_steps,
+                "react_tool_call_count": tool_call_count,
                 "tool_invocations": tool_invocations,
                 "skill_invocations": skill_invocations,
                 **memory_metadata,
@@ -931,19 +949,85 @@ class AgentApplicationService:
     def _format_react_instruction(
         self,
         *,
+        agent_run: AgentRun,
         selected_tools: list[AgentToolRefContract],
         skill_invocations: list[dict[str, JSONValue]],
     ) -> str:
+        tool_schemas = self._build_react_tool_schemas(
+            agent_run=agent_run,
+            selected_tools=selected_tools,
+        )
         return (
             "Use ReAct JSON only. Respond with one JSON object per turn.\n"
             "To call a tool: "
             '{"action":"tool","tool_code":"code","input_json":{...}}\n'
             "To finish: "
             '{"action":"finish","answer":"final answer"}\n'
-            f"Available tools: {[item.model_dump(mode='json') for item in selected_tools]}\n"
+            f"Available tools: {tool_schemas}\n"
             f"Pre-run skill results: {skill_invocations}"
         )
 
+    def _build_react_tool_schemas(  # Describe each selected tool for the ReAct prompt.
+        self,
+        *,
+        agent_run: AgentRun,  # supplies tenant_id for the binding lookup
+        selected_tools: list[AgentToolRefContract],
+    ) -> list[dict[str, JSONValue]]:  # one dict per tool ref, in input order
+        schemas: list[dict[str, JSONValue]] = []
+        for ref in selected_tools:
+            schema: dict[str, JSONValue] = {  # fields always present, copied from the ref
+                "tool_code": ref.tool_code,
+                "tool_binding_id": ref.tool_binding_id,
+                "required": ref.required,
+                "config_json": ref.config_json,
+            }
+            if ref.tool_binding_id is not None and self.tool_client is not None:  # enrich only when both exist
+                try:
+                    detail = self.tool_client.get_tool_binding_detail(
+                        tenant_id=agent_run.tenant_id,
+                        binding_id=ref.tool_binding_id,
+                    )
+                    schema.update(
+                        {
+                            "name": detail.tool_definition.name,
+                            "description": detail.tool_definition.description,
+                            "tool_type": detail.tool_definition.tool_type,
+                            "input_schema_json": detail.tool_version.input_schema_json or {},  # falsy -> {}
+                            "output_schema_json": detail.tool_version.output_schema_json or {},  # falsy -> {}
+                            "timeout_ms": detail.tool_version.timeout_ms,
+                        }
+                    )
+                except ToolServiceClientError as exc:  # lookup failure is recorded, not raised
+                    schema["schema_error"] = str(exc)
+            schemas.append(schema)  # appended even when enrichment was skipped or failed
+        return schemas
+
+    def _invoke_react_tool_with_retry(  # Invoke one tool ref until it reports "completed".
+        self,
+        *,
+        agent_run: AgentRun,
+        agent_version: AgentVersion,
+        tool_ref: AgentToolRefContract,
+    ) -> list[dict[str, JSONValue]]:  # invocation records from every attempt, oldest first
+        retry_count = self._read_int(  # key read from the version's model config
+            agent_version.model_config_json,
+            "react_tool_retry_count",
+            default=self.react_tool_retry_count,  # service-level value passed as the default
+        )
+        attempts: list[dict[str, JSONValue]] = []
+        for attempt_index in range(max(retry_count, 0) + 1):  # negatives clamp to 0: always >= 1 attempt
+            current = self._invoke_selected_tools(
+                agent_run=agent_run,
+                agent_version=agent_version,
+                selected_tools=[tool_ref],
+            )
+            for item in current:
+                item["attempt_index"] = attempt_index  # 0-based attempt number stamped on each record
+            attempts.extend(current)
+            if current and current[-1].get("status") == "completed":  # NOTE(review): any other status, or an empty result, is retried - confirm intended
+                break
+        return attempts
+
     def _format_react_observation(
         self,
         tool_invocations: list[dict[str, JSONValue]],

+ 2 - 0
services/agent-service/app/bootstrap/settings.py

@@ -20,3 +20,5 @@ class AgentServiceSettings(ServiceSettings):
     worker_max_idle_cycles: int | None = None
     worker_dry_run: bool = False
     react_max_steps: int = 5
+    react_max_tool_calls: int = 10
+    react_tool_retry_count: int = 1

+ 123 - 0
tests/test_agent_react.py

@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+
+REPO_ROOT = Path(__file__).resolve().parents[1]  # two levels up from this file
+for module_name in list(sys.modules):  # list() copy: entries are deleted while looping
+    if module_name == "app" or module_name.startswith("app."):
+        del sys.modules[module_name]  # presumably so "app" re-imports from the path added below - confirm
+for path in [
+    REPO_ROOT / "libs" / "core-domain" / "src",
+    REPO_ROOT / "libs" / "core-shared" / "src",
+    REPO_ROOT / "libs" / "core-db" / "src",
+    REPO_ROOT / "libs" / "core-events" / "src",
+    REPO_ROOT / "services" / "agent-service",
+]:
+    sys.path.insert(0, str(path))  # prepended, so the last entry (agent-service) ends up first
+
+from core_db import Base
+from core_domain import ChatCompletionResponseContract
+
+from app.application.services import AgentApplicationService
+from app.bootstrap.settings import AgentServiceSettings
+from app.db.session import build_session_factory
+from app.domain.repositories import (
+    AgentDefinitionRepository,
+    AgentRunRepository,
+    AgentToolInvocationRepository,
+    AgentVersionRepository,
+)
+from app.schemas.agent import (
+    AgentCreateRequest,
+    AgentRunCreateRequest,
+    AgentRunExecuteRequest,
+    AgentVersionCreateRequest,
+)
+
+
+class FakeModelClient:  # Scripted model stub: a tool action first, then a finish action.
+    def __init__(self) -> None:
+        self.calls = 0  # number of completions requested so far
+
+    def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
+        self.calls += 1  # payload is ignored; the reply depends only on the call count
+        if self.calls == 1:  # first call asks for the lookup_order tool
+            return ChatCompletionResponseContract(
+                model="fake",
+                content=(
+                    '{"action":"tool","tool_code":"lookup_order",'
+                    '"input_json":{"order_id":"123"}}'
+                ),
+            )
+        return ChatCompletionResponseContract(  # every later call finishes the run
+            model="fake",
+            content='{"action":"finish","answer":"Order lookup attempted."}',
+        )
+
+
+def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
+    session_factory = build_session_factory(  # ReAct run end to end on a temp SQLite file
+        settings=AgentServiceSettings(
+            database_url=f"sqlite:///{tmp_path / 'agent_service.db'}",
+        ),
+    )
+    engine = session_factory.kw["bind"]  # assumes the factory exposes its engine via .kw - TODO confirm
+    Base.metadata.create_all(bind=engine)
+
+    with session_factory() as db:
+        service = AgentApplicationService(
+            agent_repository=AgentDefinitionRepository(db),
+            agent_version_repository=AgentVersionRepository(db),
+            agent_run_repository=AgentRunRepository(db),
+            agent_tool_invocation_repository=AgentToolInvocationRepository(db),
+            model_gateway_client=FakeModelClient(),
+            memory_client=None,
+            tool_client=None,  # no tool client is wired in
+            skill_client=None,
+            event_client=None,
+            react_max_steps=3,
+        )
+        agent = service.create_agent(
+            AgentCreateRequest(tenant_id="t1", code="react", name="React")
+        )
+        service.create_agent_version(
+            AgentVersionCreateRequest(
+                tenant_id="t1",
+                agent_id=agent.id,
+                status="published",
+                system_prompt="Use ReAct.",
+                model_config={"react_enabled": True, "react_max_steps": 3},
+                memory_policy={"enabled": False, "write_enabled": False},
+                tool_refs=[
+                    {"tool_code": "lookup_order", "required": True, "config_json": {}}
+                ],
+            )
+        )
+        run = service.create_agent_run(
+            AgentRunCreateRequest(
+                tenant_id="t1",
+                agent_id=agent.id,
+                input_text="check order",
+            )
+        )
+
+        result = service.execute_agent_run(
+            agent_run_id=run.id,
+            payload=AgentRunExecuteRequest(tenant_id="t1", worker_key="test"),
+        )
+
+        assert result is not None
+        assert result.status == "completed"
+        assert result.output_text == "Order lookup attempted."  # answer from the fake's finish reply
+        assert result.output_json is not None
+        assert result.output_json["react_enabled"] is True
+        assert len(result.output_json["react_steps"]) == 2  # matches the fake's two scripted replies
+        invocations = service.list_agent_tool_invocations(
+            tenant_id="t1",
+            agent_run_id=run.id,
+        )
+        assert len(invocations) == 2  # presumably first attempt plus one retry (retry count not overridden here)
+        assert all(item.status == "skipped" for item in invocations)  # presumably because tool_client is None
+    engine.dispose()

+ 77 - 0
tests/test_knowledge_pgvector_fallback.py

@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+
+REPO_ROOT = Path(__file__).resolve().parents[1]  # two levels up from this file
+for module_name in list(sys.modules):  # list() copy: entries are deleted while looping
+    if module_name == "app" or module_name.startswith("app."):
+        del sys.modules[module_name]  # presumably so "app" re-imports from the path added below - confirm
+for path in [
+    REPO_ROOT / "libs" / "core-domain" / "src",
+    REPO_ROOT / "libs" / "core-shared" / "src",
+    REPO_ROOT / "libs" / "core-db" / "src",
+    REPO_ROOT / "services" / "knowledge-service",
+]:
+    sys.path.insert(0, str(path))  # prepended, so the last entry (knowledge-service) ends up first
+
+from core_db import Base
+
+from app.application.services import KnowledgeApplicationService
+from app.bootstrap.settings import KnowledgeServiceSettings
+from app.db.session import build_session_factory
+from app.domain.repositories import (
+    KnowledgeBaseRepository,
+    KnowledgeChunkRepository,
+    KnowledgeDocumentRepository,
+)
+from app.schemas.knowledge import (
+    KnowledgeBaseCreateRequest,
+    KnowledgeDocumentCreateRequest,
+    KnowledgeSearchRequest,
+)
+
+
+def test_knowledge_search_falls_back_without_pgvector(tmp_path: Path) -> None:
+    settings = KnowledgeServiceSettings(  # search must still return results on a SQLite database
+        database_url=f"sqlite:///{tmp_path / 'knowledge_service.db'}",
+        embedding_provider="local",  # presumably needs no external embedding service - verify
+    )
+    session_factory = build_session_factory(settings)
+    Base.metadata.create_all(bind=session_factory.kw["bind"])  # assumes the engine is exposed via .kw - TODO confirm
+
+    with session_factory() as db:
+        service = KnowledgeApplicationService(
+            settings=settings,
+            base_repository=KnowledgeBaseRepository(db),
+            document_repository=KnowledgeDocumentRepository(db),
+            chunk_repository=KnowledgeChunkRepository(db),
+        )
+        base = service.create_base(
+            KnowledgeBaseCreateRequest(tenant_id="t1", code="kb", name="KB")
+        )
+        _, chunks = service.create_document(  # first element is discarded; only the chunks are checked
+            KnowledgeDocumentCreateRequest(
+                tenant_id="t1",
+                knowledge_base_id=base.id,
+                title="Refund Policy",
+                content_text="Refunds are available within seven days for eligible orders.",
+                chunk_size=40,
+                chunk_overlap=5,
+            )
+        )
+        assert chunks[0].embedding_vector is not None  # the first chunk carries an embedding
+
+        results = service.search(
+            KnowledgeSearchRequest(
+                tenant_id="t1",
+                knowledge_base_id=base.id,
+                query="refund seven days",
+                top_k=3,
+            )
+        )
+
+        assert results  # at least one hit
+        assert results[0][3]["retrieval_mode"] == "hybrid"  # index 3 is presumably the hit's metadata dict - TODO confirm
+    session_factory.kw["bind"].dispose()