浏览代码

feat: integrate agent memory

Jax Docker 1 月之前
父节点
当前提交
a688c3d241

+ 17 - 0
README.md

@@ -291,6 +291,23 @@ Invoke-RestMethod -Method Post `
   -Body '{"tenant_id":"t1","worker_key":"agent-worker-1"}'
 ```
 
+Agent memory policy is stored on `agent_version.memory_policy_json`:
+
+- `enabled`: read memories before execution
+- `memory_scope`: one of `tenant`, `user`, `session`, `agent`, or `team`
+- `read_top_k`: maximum memories to inject into the prompt
+- `write_enabled`: write a conversation memory after successful model execution
+- `config_json.write_importance_score`: optional importance score for written memories
+
+Example version with session memory:
+
+```powershell
+Invoke-RestMethod -Method Post `
+  -Uri http://127.0.0.1:8007/agents/versions `
+  -ContentType "application/json" `
+  -Body '{"tenant_id":"t1","agent_id":"agent-id","status":"published","role":"assistant","system_prompt":"Use relevant memory when helpful.","memory_policy":{"enabled":true,"memory_scope":"session","read_top_k":5,"write_enabled":true,"config_json":{"write_importance_score":60}}}'
+```
+
 Execute one queued agent run through the worker claim API:
 
 ```powershell

+ 6 - 0
deployments/docker/docker-compose.yml

@@ -113,6 +113,7 @@ services:
     environment:
       AGENT_PLATFORM_DATABASE_URL: sqlite:////data/agent_service.db
       AGENT_PLATFORM_MODEL_GATEWAY_SERVICE_URL: http://model-gateway-service:8005
+      AGENT_PLATFORM_MEMORY_SERVICE_URL: http://memory-service:8008
     ports:
       - "8007:8007"
     volumes:
@@ -120,6 +121,8 @@ services:
     depends_on:
       model-gateway-service:
         condition: service_started
+      memory-service:
+        condition: service_started
     healthcheck:
       test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8007/agents/health').read()"]
       interval: 15s
@@ -136,6 +139,7 @@ services:
     environment:
       AGENT_PLATFORM_DATABASE_URL: sqlite:////data/agent_service.db
       AGENT_PLATFORM_MODEL_GATEWAY_SERVICE_URL: http://model-gateway-service:8005
+      AGENT_PLATFORM_MEMORY_SERVICE_URL: http://memory-service:8008
       AGENT_PLATFORM_WORKER_POLL_INTERVAL_SECONDS: ${AGENT_PLATFORM_WORKER_POLL_INTERVAL_SECONDS:-1}
       AGENT_PLATFORM_WORKER_LEASE_SECONDS: ${AGENT_PLATFORM_WORKER_LEASE_SECONDS:-300}
       AGENT_PLATFORM_WORKER_DRY_RUN: ${AGENT_PLATFORM_AGENT_WORKER_DRY_RUN:-false}
@@ -144,6 +148,8 @@ services:
     depends_on:
       model-gateway-service:
         condition: service_started
+      memory-service:
+        condition: service_started
 
   memory-service:
     build:

+ 4 - 0
libs/core-domain/src/core_domain/__init__.py

@@ -23,8 +23,10 @@ from .model_contracts import (
     ChatMessageContract,
 )
 from .memory_contracts import (
+    MemoryCreateContract,
     MemoryItemContract,
     MemoryScopeType,
+    MemorySearchRequestContract,
     MemorySearchResultContract,
     MemoryStatus,
 )
@@ -65,8 +67,10 @@ __all__ = [
     "ChatCompletionResponseContract",
     "ChatMessageContract",
     "InitialNodeContract",
+    "MemoryCreateContract",
     "MemoryItemContract",
     "MemoryScopeType",
+    "MemorySearchRequestContract",
     "MemorySearchResultContract",
     "MemoryStatus",
     "NodeExecutionContextContract",

+ 27 - 0
libs/core-domain/src/core_domain/memory_contracts.py

@@ -30,6 +30,33 @@ class MemoryItemContract(BaseModel):
     created_time: datetime
 
 
+class MemoryCreateContract(BaseModel):
+    tenant_id: str
+    scope_type: MemoryScopeType
+    scope_id: str
+    memory_type: str = "fact"
+    content_text: str
+    content_json: dict[str, JSONValue] | None = None
+    metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
+    owner_agent_id: str | None = None
+    user_id: str | None = None
+    session_id: str | None = None
+    source_ref: str | None = None
+    importance_score: int = Field(default=0, ge=0, le=100)
+    expires_time: datetime | None = None
+
+
+class MemorySearchRequestContract(BaseModel):
+    tenant_id: str
+    query: str
+    scope_type: MemoryScopeType | None = None
+    scope_id: str | None = None
+    owner_agent_id: str | None = None
+    user_id: str | None = None
+    session_id: str | None = None
+    limit: int = Field(default=8, ge=1, le=100)
+
+
 class MemorySearchResultContract(BaseModel):
     item: MemoryItemContract
     score: float

+ 231 - 2
services/agent-service/app/application/services.py

@@ -1,8 +1,16 @@
 from datetime import datetime, timedelta
+from typing import cast
 
 from sqlalchemy.orm import Session
 
-from core_domain import ChatCompletionRequestContract, ChatMessageContract
+from core_domain import (
+    ChatCompletionRequestContract,
+    ChatMessageContract,
+    MemoryCreateContract,
+    MemoryScopeType,
+    MemorySearchRequestContract,
+    MemorySearchResultContract,
+)
 from core_shared import JSONValue
 
 from app.bootstrap.settings import AgentServiceSettings
@@ -13,6 +21,7 @@ from app.domain.repositories import (
     AgentVersionRepository,
 )
 from app.infrastructure.model_gateway_client import ModelGatewayClient, ModelGatewayClientError
+from app.infrastructure.memory_client import MemoryClient, MemoryClientError
 from app.schemas.agent import (
     AgentCreateRequest,
     AgentRunCreateRequest,
@@ -31,11 +40,13 @@ class AgentApplicationService:
         agent_version_repository: AgentVersionRepository,
         agent_run_repository: AgentRunRepository,
         model_gateway_client: ModelGatewayClient | None = None,
+        memory_client: MemoryClient | None = None,
     ) -> None:
         self.agent_repository = agent_repository
         self.agent_version_repository = agent_version_repository
         self.agent_run_repository = agent_run_repository
         self.model_gateway_client = model_gateway_client
+        self.memory_client = memory_client
 
     def create_agent(self, payload: AgentCreateRequest) -> AgentDefinition:
         return self.agent_repository.create(
@@ -172,7 +183,15 @@ class AgentApplicationService:
             worker_key=payload.worker_key,
         )
 
-        messages = self._build_chat_messages(agent_run=agent_run, agent_version=agent_version)
+        memory_results, memory_metadata = self._read_relevant_memories(
+            agent_run=agent_run,
+            agent_version=agent_version,
+        )
+        messages = self._build_chat_messages(
+            agent_run=agent_run,
+            agent_version=agent_version,
+            memory_results=memory_results,
+        )
         if payload.dry_run:
             return self.agent_run_repository.update_status(
                 agent_run_id=agent_run.id,
@@ -187,6 +206,7 @@ class AgentApplicationService:
                     "agent_version_id": agent_version.id,
                     "message_count": len(messages),
                     "messages": [message.model_dump(mode="json") for message in messages],
+                    **memory_metadata,
                 },
             )
 
@@ -226,6 +246,11 @@ class AgentApplicationService:
                 error_message=str(exc),
             )
 
+        memory_write_metadata = self._write_interaction_memory(
+            agent_run=agent_run,
+            agent_version=agent_version,
+            output_text=response.content,
+        )
         return self.agent_run_repository.update_status(
             agent_run_id=agent_run.id,
             status="completed",
@@ -238,6 +263,8 @@ class AgentApplicationService:
                 "finish_reason": response.finish_reason,
                 "usage_json": response.usage_json,
                 "raw_response_json": response.raw_response_json,
+                **memory_metadata,
+                **memory_write_metadata,
             },
         )
 
@@ -292,12 +319,20 @@ class AgentApplicationService:
         *,
         agent_run: AgentRun,
         agent_version: AgentVersion,
+        memory_results: list[MemorySearchResultContract] | None = None,
     ) -> list[ChatMessageContract]:
         messages = [
             ChatMessageContract(role="system", content=agent_version.system_prompt),
         ]
         if agent_version.goal:
             messages.append(ChatMessageContract(role="system", content=f"Goal: {agent_version.goal}"))
+        if memory_results:
+            messages.append(
+                ChatMessageContract(
+                    role="system",
+                    content=self._format_memory_context(memory_results),
+                )
+            )
         if agent_run.input_text:
             messages.append(ChatMessageContract(role="user", content=agent_run.input_text))
         if agent_run.input_json:
@@ -334,6 +369,196 @@ class AgentApplicationService:
             return value
         return None
 
+    def _read_relevant_memories(
+        self,
+        *,
+        agent_run: AgentRun,
+        agent_version: AgentVersion,
+    ) -> tuple[list[MemorySearchResultContract], dict[str, JSONValue]]:
+        if self.memory_client is None:
+            return [], {"memory_read_enabled": False, "memory_read_reason": "client_missing"}
+        if not self._read_bool(agent_version.memory_policy_json, "enabled", default=True):
+            return [], {"memory_read_enabled": False, "memory_read_reason": "policy_disabled"}
+
+        query = agent_run.input_text or str(agent_run.input_json or "")
+        if not query:
+            return [], {"memory_read_enabled": True, "memory_read_count": 0}
+
+        scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
+        if scope is None:
+            return [], {
+                "memory_read_enabled": True,
+                "memory_read_count": 0,
+                "memory_read_reason": "scope_unavailable",
+            }
+
+        scope_type, scope_id = scope
+        try:
+            results = self.memory_client.search_memories(
+                MemorySearchRequestContract(
+                    tenant_id=agent_run.tenant_id,
+                    query=query,
+                    scope_type=scope_type,
+                    scope_id=scope_id,
+                    owner_agent_id=agent_run.agent_id,
+                    session_id=agent_run.session_id,
+                    limit=self._read_int(
+                        agent_version.memory_policy_json,
+                        "read_top_k",
+                        default=8,
+                    ),
+                )
+            )
+        except MemoryClientError as exc:
+            return [], {
+                "memory_read_enabled": True,
+                "memory_read_count": 0,
+                "memory_read_error": str(exc),
+            }
+
+        return results, {
+            "memory_read_enabled": True,
+            "memory_read_count": len(results),
+            "memory_scope_type": scope_type,
+            "memory_scope_id": scope_id,
+        }
+
+    def _write_interaction_memory(
+        self,
+        *,
+        agent_run: AgentRun,
+        agent_version: AgentVersion,
+        output_text: str,
+    ) -> dict[str, JSONValue]:
+        if self.memory_client is None:
+            return {"memory_write_enabled": False, "memory_write_reason": "client_missing"}
+        if not self._read_bool(agent_version.memory_policy_json, "write_enabled", default=True):
+            return {"memory_write_enabled": False, "memory_write_reason": "policy_disabled"}
+
+        scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
+        if scope is None:
+            return {"memory_write_enabled": True, "memory_write_reason": "scope_unavailable"}
+
+        scope_type, scope_id = scope
+        try:
+            memory = self.memory_client.create_memory(
+                MemoryCreateContract(
+                    tenant_id=agent_run.tenant_id,
+                    scope_type=scope_type,
+                    scope_id=scope_id,
+                    memory_type="conversation",
+                    content_text=self._format_interaction_memory(
+                        agent_run=agent_run,
+                        output_text=output_text,
+                    ),
+                    content_json={
+                        "agent_run_id": agent_run.id,
+                        "agent_version_id": agent_version.id,
+                        "input_text": agent_run.input_text,
+                        "output_text": output_text,
+                    },
+                    metadata_json={
+                        "source": "agent-service",
+                        "role": agent_version.role,
+                        "version_no": agent_version.version_no,
+                    },
+                    owner_agent_id=agent_run.agent_id,
+                    session_id=agent_run.session_id,
+                    source_ref=f"agent_run:{agent_run.id}",
+                    importance_score=self._read_nested_int(
+                        agent_version.memory_policy_json,
+                        "config_json",
+                        "write_importance_score",
+                        default=50,
+                    ),
+                )
+            )
+        except MemoryClientError as exc:
+            return {
+                "memory_write_enabled": True,
+                "memory_write_error": str(exc),
+            }
+
+        return {
+            "memory_write_enabled": True,
+            "memory_written_id": memory.id,
+            "memory_scope_type": scope_type,
+            "memory_scope_id": scope_id,
+        }
+
+    def _resolve_memory_scope(
+        self,
+        *,
+        agent_run: AgentRun,
+        agent_version: AgentVersion,
+    ) -> tuple[MemoryScopeType, str] | None:
+        scope_value = self._read_optional_string(
+            agent_version.memory_policy_json,
+            "memory_scope",
+        ) or "session"
+        if scope_value == "tenant":
+            return "tenant", agent_run.tenant_id
+        if scope_value == "agent":
+            return "agent", agent_run.agent_id
+        if scope_value == "session" and agent_run.session_id:
+            return "session", agent_run.session_id
+        if scope_value == "user":
+            user_id = self._read_input_json_string(agent_run=agent_run, key="user_id")
+            if user_id is not None:
+                return "user", user_id
+        if scope_value == "team":
+            team_id = self._read_input_json_string(agent_run=agent_run, key="team_id")
+            if team_id is not None:
+                return "team", team_id
+        return None
+
+    def _format_memory_context(self, memory_results: list[MemorySearchResultContract]) -> str:
+        lines = ["Relevant memories:"]
+        for index, result in enumerate(memory_results, start=1):
+            lines.append(f"{index}. {result.item.content_text}")
+        return "\n".join(lines)
+
+    def _format_interaction_memory(self, *, agent_run: AgentRun, output_text: str) -> str:
+        input_text = agent_run.input_text or str(agent_run.input_json or {})
+        return f"User input: {input_text}\nAgent output: {output_text}"
+
+    def _read_bool(self, payload: dict[str, JSONValue], key: str, *, default: bool) -> bool:
+        value = payload.get(key)
+        if isinstance(value, bool):
+            return value
+        return default
+
+    def _read_int(self, payload: dict[str, JSONValue], key: str, *, default: int) -> int:
+        value = payload.get(key)
+        if isinstance(value, int) and not isinstance(value, bool):
+            return value
+        return default
+
+    def _read_nested_int(
+        self,
+        payload: dict[str, JSONValue],
+        parent_key: str,
+        child_key: str,
+        *,
+        default: int,
+    ) -> int:
+        parent_value = payload.get(parent_key)
+        if not isinstance(parent_value, dict):
+            return default
+        return self._read_int(
+            cast(dict[str, JSONValue], parent_value),
+            child_key,
+            default=default,
+        )
+
+    def _read_input_json_string(self, *, agent_run: AgentRun, key: str) -> str | None:
+        if agent_run.input_json is None:
+            return None
+        value = agent_run.input_json.get(key)
+        if isinstance(value, str) and value:
+            return value
+        return None
+
 
 def build_agent_application_service(
     *,
@@ -348,4 +573,8 @@ def build_agent_application_service(
             base_url=settings.model_gateway_service_url,
             timeout_seconds=settings.model_gateway_timeout_seconds,
         ),
+        memory_client=MemoryClient(
+            base_url=settings.memory_service_url,
+            timeout_seconds=settings.memory_service_timeout_seconds,
+        ),
     )

+ 2 - 0
services/agent-service/app/bootstrap/settings.py

@@ -7,6 +7,8 @@ class AgentServiceSettings(ServiceSettings):
     database_url: str = "sqlite:///./agent_service.db"
     model_gateway_service_url: str = "http://127.0.0.1:8005"
     model_gateway_timeout_seconds: float = 60.0
+    memory_service_url: str = "http://127.0.0.1:8008"
+    memory_service_timeout_seconds: float = 10.0
     worker_poll_interval_seconds: float = 1.0
     worker_lease_seconds: int = 300
     worker_max_idle_cycles: int | None = None

+ 48 - 0
services/agent-service/app/infrastructure/memory_client.py

@@ -0,0 +1,48 @@
+import httpx
+
+from core_domain import (
+    MemoryCreateContract,
+    MemoryItemContract,
+    MemorySearchRequestContract,
+    MemorySearchResultContract,
+)
+
+
+class MemoryClientError(Exception):
+    pass
+
+
+class MemoryClient:
+    def __init__(self, *, base_url: str, timeout_seconds: float = 10.0) -> None:
+        self.base_url = base_url.rstrip("/")
+        self.timeout_seconds = timeout_seconds
+
+    def create_memory(self, payload: MemoryCreateContract) -> MemoryItemContract:
+        try:
+            with httpx.Client(timeout=self.timeout_seconds) as client:
+                response = client.post(
+                    f"{self.base_url}/memories",
+                    json=payload.model_dump(mode="json"),
+                )
+                response.raise_for_status()
+                return MemoryItemContract.model_validate(response.json())
+        except httpx.HTTPError as exc:
+            raise MemoryClientError(f"memory-service create request failed: {exc}") from exc
+
+    def search_memories(
+        self,
+        payload: MemorySearchRequestContract,
+    ) -> list[MemorySearchResultContract]:
+        try:
+            with httpx.Client(timeout=self.timeout_seconds) as client:
+                response = client.post(
+                    f"{self.base_url}/memories/search",
+                    json=payload.model_dump(mode="json"),
+                )
+                response.raise_for_status()
+                return [
+                    MemorySearchResultContract.model_validate(item)
+                    for item in response.json()
+                ]
+        except httpx.HTTPError as exc:
+            raise MemoryClientError(f"memory-service search request failed: {exc}") from exc

+ 7 - 27
services/memory-service/app/schemas/memory.py

@@ -1,34 +1,21 @@
-from datetime import datetime
 from typing import TYPE_CHECKING
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from core_domain import (
+    MemoryCreateContract,
     MemoryItemContract,
-    MemoryScopeType,
+    MemorySearchRequestContract,
     MemorySearchResultContract,
     MemoryStatus,
 )
-from core_shared import JSONValue
 
 if TYPE_CHECKING:
     from app.db.models import MemoryItem
 
 
-class MemoryCreateRequest(BaseModel):
-    tenant_id: str
-    scope_type: MemoryScopeType
-    scope_id: str
-    memory_type: str = "fact"
-    content_text: str
-    content_json: dict[str, JSONValue] | None = None
-    metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
-    owner_agent_id: str | None = None
-    user_id: str | None = None
-    session_id: str | None = None
-    source_ref: str | None = None
-    importance_score: int = Field(default=0, ge=0, le=100)
-    expires_time: datetime | None = None
+class MemoryCreateRequest(MemoryCreateContract):
+    pass
 
 
 class MemoryStatusUpdateRequest(BaseModel):
@@ -36,15 +23,8 @@ class MemoryStatusUpdateRequest(BaseModel):
     status: MemoryStatus
 
 
-class MemorySearchRequest(BaseModel):
-    tenant_id: str
-    query: str
-    scope_type: MemoryScopeType | None = None
-    scope_id: str | None = None
-    owner_agent_id: str | None = None
-    user_id: str | None = None
-    session_id: str | None = None
-    limit: int = Field(default=8, ge=1, le=100)
+class MemorySearchRequest(MemorySearchRequestContract):
+    pass
 
 
 class MemoryResponse(MemoryItemContract):