Sfoglia il codice sorgente

feat: support openai tool calls

Jax Docker 1 mese fa
parent
commit
b28b52a033

+ 2 - 0
libs/core-domain/src/core_domain/agent_contracts.py

@@ -20,6 +20,8 @@ class AgentModelConfigContract(BaseModel):
     react_max_steps: int | None = None
     react_max_tool_calls: int | None = None
     react_tool_retry_count: int | None = None
+    function_calling_enabled: bool = False
+    tool_calling_enabled: bool = False
     extra_json: dict[str, JSONValue] = Field(default_factory=dict)
 
 

+ 3 - 0
libs/core-domain/src/core_domain/model_contracts.py

@@ -14,6 +14,8 @@ class ChatCompletionRequestContract(BaseModel):
     messages: list[ChatMessageContract] = Field(default_factory=list)
     temperature: float | None = None
     max_tokens: int | None = None
+    tools_json: list[dict[str, JSONValue]] = Field(default_factory=list)
+    tool_choice: str | dict[str, JSONValue] | None = None
     metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
 
 
@@ -21,5 +23,6 @@ class ChatCompletionResponseContract(BaseModel):
     model: str | None = None
     content: str
     finish_reason: str | None = None
+    tool_calls_json: list[dict[str, JSONValue]] = Field(default_factory=list)
     usage_json: dict[str, JSONValue] = Field(default_factory=dict)
     raw_response_json: dict[str, JSONValue] = Field(default_factory=dict)

+ 90 - 1
services/agent-service/app/application/services.py

@@ -9,6 +9,7 @@ from core_domain import (
     AgentSkillRefContract,
     AgentToolRefContract,
     ChatCompletionRequestContract,
+    ChatCompletionResponseContract,
     ChatMessageContract,
     MemoryCreateContract,
     MemoryScopeType,
@@ -469,6 +470,7 @@ class AgentApplicationService:
                         agent_run=agent_run,
                         agent_version=agent_version,
                         messages=messages,
+                        selected_tools=selected_tools,
                     )
                 )
             except ModelGatewayClientError as exc:
@@ -486,7 +488,7 @@ class AgentApplicationService:
                     },
                 )
 
-            action = self._parse_react_action(response.content)
+            action = self._parse_react_action_from_response(response)
             react_step: dict[str, JSONValue] = {
                 "step_index": step_index,
                 "model_content": response.content,
@@ -1058,13 +1060,63 @@ class AgentApplicationService:
             return {"action": "finish", "answer": content}
         return {str(item_key): item_value for item_key, item_value in value.items()}
 
+    def _parse_react_action_from_response(
+        self,
+        response: ChatCompletionResponseContract,
+    ) -> dict[str, JSONValue]:
+        if response.tool_calls_json:
+            action = self._parse_openai_tool_call(response.tool_calls_json[0])
+            if action is not None:
+                return action
+        return self._parse_react_action(response.content)
+
+    def _parse_openai_tool_call(
+        self,
+        tool_call: dict[str, JSONValue],
+    ) -> dict[str, JSONValue] | None:
+        function_value = tool_call.get("function")
+        if not isinstance(function_value, dict):
+            return None
+        tool_code = function_value.get("name")
+        if not isinstance(tool_code, str) or not tool_code:
+            return None
+        raw_arguments = function_value.get("arguments")
+        input_json: dict[str, JSONValue] = {}
+        if isinstance(raw_arguments, str) and raw_arguments:
+            try:
+                decoded = json.loads(raw_arguments)
+            except json.JSONDecodeError:
+                decoded = {"raw_arguments": raw_arguments}
+            if isinstance(decoded, dict):
+                input_json = {str(item_key): item_value for item_key, item_value in decoded.items()}
+        elif isinstance(raw_arguments, dict):
+            input_json = {str(item_key): item_value for item_key, item_value in raw_arguments.items()}
+        tool_call_id = tool_call.get("id")
+        return {
+            "action": "tool",
+            "tool_code": tool_code,
+            "input_json": input_json,
+            "tool_call_id": tool_call_id if isinstance(tool_call_id, str) else None,
+            "tool_call_protocol": "openai",
+        }
+
     def _build_chat_completion_request(
         self,
         *,
         agent_run: AgentRun,
         agent_version: AgentVersion,
         messages: list[ChatMessageContract],
+        selected_tools: list[AgentToolRefContract] | None = None,
     ) -> ChatCompletionRequestContract:
+        function_calling_enabled = self._read_bool(
+            agent_version.model_config_json,
+            "function_calling_enabled",
+            default=False,
+        ) or self._read_bool(
+            agent_version.model_config_json,
+            "tool_calling_enabled",
+            default=False,
+        )
         return ChatCompletionRequestContract(
             model=self._read_optional_string(agent_version.model_config_json, "model"),
             temperature=self._read_optional_float(
@@ -1076,6 +1128,15 @@ class AgentApplicationService:
                 "max_tokens",
             ),
             messages=messages,
+            tools_json=(
+                self._build_openai_tool_schemas(
+                    agent_run=agent_run,
+                    selected_tools=selected_tools or [],
+                )
+                if function_calling_enabled
+                else []
+            ),
+            tool_choice="auto" if function_calling_enabled and selected_tools else None,
             metadata_json={
                 "tenant_id": agent_run.tenant_id,
                 "agent_id": agent_run.agent_id,
@@ -1084,6 +1145,34 @@ class AgentApplicationService:
             },
         )
 
+    def _build_openai_tool_schemas(
+        self,
+        *,
+        agent_run: AgentRun,
+        selected_tools: list[AgentToolRefContract],
+    ) -> list[dict[str, JSONValue]]:
+        tool_schemas: list[dict[str, JSONValue]] = []
+        for schema in self._build_react_tool_schemas(
+            agent_run=agent_run,
+            selected_tools=selected_tools,
+        ):
+            tool_code = schema.get("tool_code")
+            if not isinstance(tool_code, str) or not tool_code:
+                continue
+            description = schema.get("description")
+            input_schema = schema.get("input_schema_json")
+            tool_schemas.append(
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_code,
+                        "description": description if isinstance(description, str) else "",
+                        "parameters": input_schema if isinstance(input_schema, dict) else {},
+                    },
+                }
+            )
+        return tool_schemas
+
     def _build_skill_input_json(
         self,
         *,

+ 23 - 0
services/model-gateway-service/app/infrastructure/provider.py

@@ -29,6 +29,10 @@ class ModelProviderClient:
             request_payload["temperature"] = payload.temperature
         if payload.max_tokens is not None:
             request_payload["max_tokens"] = payload.max_tokens
+        if payload.tools_json:
+            request_payload["tools"] = payload.tools_json
+        if payload.tool_choice is not None:
+            request_payload["tool_choice"] = payload.tool_choice
 
         request_headers: dict[str, str] = {"content-type": "application/json"}
         if self.settings.provider_api_key:
@@ -48,11 +52,13 @@ class ModelProviderClient:
         response_json = _coerce_json_dict(response.json())
         content = _extract_response_content(response_json)
         finish_reason = _extract_finish_reason(response_json)
+        tool_calls_json = _extract_tool_calls_json(response_json)
         usage_json = _extract_usage_json(response_json)
         return ChatCompletionResponseContract(
             model=payload.model,
             content=content,
             finish_reason=finish_reason,
+            tool_calls_json=tool_calls_json,
             usage_json=usage_json,
             raw_response_json=response_json,
         )
@@ -91,6 +97,23 @@ def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
     return None
 
 
+def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
+    choices = payload.get("choices")
+    if isinstance(choices, list) and choices:
+        first_choice = choices[0]
+        if isinstance(first_choice, dict):
+            message = first_choice.get("message")
+            if isinstance(message, dict):
+                tool_calls = message.get("tool_calls")
+                if isinstance(tool_calls, list):
+                    return [
+                        {str(item_key): item_value for item_key, item_value in item.items()}
+                        for item in tool_calls
+                        if isinstance(item, dict)
+                    ]
+    return []
+
+
 def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
     usage = payload.get("usage")
     if isinstance(usage, dict):

+ 96 - 0
tests/test_agent_react.py

@@ -57,6 +57,36 @@ class FakeModelClient:
         )
 
 
+class FakeFunctionCallingModelClient:
+    def __init__(self) -> None:
+        self.calls = 0
+        self.request_tools: list[object] = []
+
+    def create_chat_completion(self, payload: object) -> ChatCompletionResponseContract:
+        self.calls += 1
+        self.request_tools.append(getattr(payload, "tools_json", []))
+        if self.calls == 1:
+            return ChatCompletionResponseContract(
+                model="fake",
+                content="",
+                finish_reason="tool_calls",
+                tool_calls_json=[
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "lookup_order",
+                            "arguments": '{"order_id":"123"}',
+                        },
+                    }
+                ],
+            )
+        return ChatCompletionResponseContract(
+            model="fake",
+            content='{"action":"finish","answer":"Function call handled."}',
+        )
+
+
 def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
     session_factory = build_session_factory(
         settings=AgentServiceSettings(
@@ -121,3 +151,69 @@ def test_react_loop_records_steps_and_tool_invocation(tmp_path: Path) -> None:
         assert len(invocations) == 2
         assert all(item.status == "skipped" for item in invocations)
     engine.dispose()
+
+
+def test_react_loop_accepts_openai_tool_calls(tmp_path: Path) -> None:
+    session_factory = build_session_factory(
+        settings=AgentServiceSettings(
+            database_url=f"sqlite:///{tmp_path / 'agent_service.db'}",
+        ),
+    )
+    engine = session_factory.kw["bind"]
+    Base.metadata.create_all(bind=engine)
+
+    with session_factory() as db:
+        model_client = FakeFunctionCallingModelClient()
+        service = AgentApplicationService(
+            agent_repository=AgentDefinitionRepository(db),
+            agent_version_repository=AgentVersionRepository(db),
+            agent_run_repository=AgentRunRepository(db),
+            agent_tool_invocation_repository=AgentToolInvocationRepository(db),
+            model_gateway_client=model_client,
+            memory_client=None,
+            tool_client=None,
+            skill_client=None,
+            event_client=None,
+            react_max_steps=3,
+        )
+        agent = service.create_agent(
+            AgentCreateRequest(tenant_id="t1", code="react-fn", name="React Function")
+        )
+        service.create_agent_version(
+            AgentVersionCreateRequest(
+                tenant_id="t1",
+                agent_id=agent.id,
+                status="published",
+                system_prompt="Use ReAct.",
+                model_config={
+                    "react_enabled": True,
+                    "react_max_steps": 3,
+                    "function_calling_enabled": True,
+                },
+                memory_policy={"enabled": False, "write_enabled": False},
+                tool_refs=[
+                    {"tool_code": "lookup_order", "required": True, "config_json": {}}
+                ],
+            )
+        )
+        run = service.create_agent_run(
+            AgentRunCreateRequest(
+                tenant_id="t1",
+                agent_id=agent.id,
+                input_text="check order",
+            )
+        )
+
+        result = service.execute_agent_run(
+            agent_run_id=run.id,
+            payload=AgentRunExecuteRequest(tenant_id="t1", worker_key="test"),
+        )
+
+        assert result is not None
+        assert result.status == "completed"
+        assert result.output_text == "Function call handled."
+        assert result.output_json is not None
+        assert result.output_json["react_steps"][0]["action"]["tool_call_protocol"] == "openai"
+        assert result.output_json["react_steps"][0]["action"]["input_json"] == {"order_id": "123"}
+        assert model_client.request_tools[0]
+    engine.dispose()