|
|
@@ -1,8 +1,16 @@
|
|
|
from datetime import datetime, timedelta
|
|
|
+from typing import cast
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
-from core_domain import ChatCompletionRequestContract, ChatMessageContract
|
|
|
+from core_domain import (
|
|
|
+ ChatCompletionRequestContract,
|
|
|
+ ChatMessageContract,
|
|
|
+ MemoryCreateContract,
|
|
|
+ MemoryScopeType,
|
|
|
+ MemorySearchRequestContract,
|
|
|
+ MemorySearchResultContract,
|
|
|
+)
|
|
|
from core_shared import JSONValue
|
|
|
|
|
|
from app.bootstrap.settings import AgentServiceSettings
|
|
|
@@ -13,6 +21,7 @@ from app.domain.repositories import (
|
|
|
AgentVersionRepository,
|
|
|
)
|
|
|
from app.infrastructure.model_gateway_client import ModelGatewayClient, ModelGatewayClientError
|
|
|
+from app.infrastructure.memory_client import MemoryClient, MemoryClientError
|
|
|
from app.schemas.agent import (
|
|
|
AgentCreateRequest,
|
|
|
AgentRunCreateRequest,
|
|
|
@@ -31,11 +40,13 @@ class AgentApplicationService:
|
|
|
agent_version_repository: AgentVersionRepository,
|
|
|
agent_run_repository: AgentRunRepository,
|
|
|
model_gateway_client: ModelGatewayClient | None = None,
|
|
|
+ memory_client: MemoryClient | None = None,
|
|
|
) -> None:
|
|
|
self.agent_repository = agent_repository
|
|
|
self.agent_version_repository = agent_version_repository
|
|
|
self.agent_run_repository = agent_run_repository
|
|
|
self.model_gateway_client = model_gateway_client
|
|
|
+ self.memory_client = memory_client
|
|
|
|
|
|
def create_agent(self, payload: AgentCreateRequest) -> AgentDefinition:
|
|
|
return self.agent_repository.create(
|
|
|
@@ -172,7 +183,15 @@ class AgentApplicationService:
|
|
|
worker_key=payload.worker_key,
|
|
|
)
|
|
|
|
|
|
- messages = self._build_chat_messages(agent_run=agent_run, agent_version=agent_version)
|
|
|
+ memory_results, memory_metadata = self._read_relevant_memories(
|
|
|
+ agent_run=agent_run,
|
|
|
+ agent_version=agent_version,
|
|
|
+ )
|
|
|
+ messages = self._build_chat_messages(
|
|
|
+ agent_run=agent_run,
|
|
|
+ agent_version=agent_version,
|
|
|
+ memory_results=memory_results,
|
|
|
+ )
|
|
|
if payload.dry_run:
|
|
|
return self.agent_run_repository.update_status(
|
|
|
agent_run_id=agent_run.id,
|
|
|
@@ -187,6 +206,7 @@ class AgentApplicationService:
|
|
|
"agent_version_id": agent_version.id,
|
|
|
"message_count": len(messages),
|
|
|
"messages": [message.model_dump(mode="json") for message in messages],
|
|
|
+ **memory_metadata,
|
|
|
},
|
|
|
)
|
|
|
|
|
|
@@ -226,6 +246,11 @@ class AgentApplicationService:
|
|
|
error_message=str(exc),
|
|
|
)
|
|
|
|
|
|
+ memory_write_metadata = self._write_interaction_memory(
|
|
|
+ agent_run=agent_run,
|
|
|
+ agent_version=agent_version,
|
|
|
+ output_text=response.content,
|
|
|
+ )
|
|
|
return self.agent_run_repository.update_status(
|
|
|
agent_run_id=agent_run.id,
|
|
|
status="completed",
|
|
|
@@ -238,6 +263,8 @@ class AgentApplicationService:
|
|
|
"finish_reason": response.finish_reason,
|
|
|
"usage_json": response.usage_json,
|
|
|
"raw_response_json": response.raw_response_json,
|
|
|
+ **memory_metadata,
|
|
|
+ **memory_write_metadata,
|
|
|
},
|
|
|
)
|
|
|
|
|
|
@@ -292,12 +319,20 @@ class AgentApplicationService:
|
|
|
*,
|
|
|
agent_run: AgentRun,
|
|
|
agent_version: AgentVersion,
|
|
|
+ memory_results: list[MemorySearchResultContract] | None = None,
|
|
|
) -> list[ChatMessageContract]:
|
|
|
messages = [
|
|
|
ChatMessageContract(role="system", content=agent_version.system_prompt),
|
|
|
]
|
|
|
if agent_version.goal:
|
|
|
messages.append(ChatMessageContract(role="system", content=f"Goal: {agent_version.goal}"))
|
|
|
+ if memory_results:
|
|
|
+ messages.append(
|
|
|
+ ChatMessageContract(
|
|
|
+ role="system",
|
|
|
+ content=self._format_memory_context(memory_results),
|
|
|
+ )
|
|
|
+ )
|
|
|
if agent_run.input_text:
|
|
|
messages.append(ChatMessageContract(role="user", content=agent_run.input_text))
|
|
|
if agent_run.input_json:
|
|
|
@@ -334,6 +369,196 @@ class AgentApplicationService:
|
|
|
return value
|
|
|
return None
|
|
|
|
|
|
+ def _read_relevant_memories(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ agent_run: AgentRun,
|
|
|
+ agent_version: AgentVersion,
|
|
|
+ ) -> tuple[list[MemorySearchResultContract], dict[str, JSONValue]]:
|
|
|
+ if self.memory_client is None:
|
|
|
+ return [], {"memory_read_enabled": False, "memory_read_reason": "client_missing"}
|
|
|
+ if not self._read_bool(agent_version.memory_policy_json, "enabled", default=True):
|
|
|
+ return [], {"memory_read_enabled": False, "memory_read_reason": "policy_disabled"}
|
|
|
+
|
|
|
+ query = agent_run.input_text or str(agent_run.input_json or "")
|
|
|
+ if not query:
|
|
|
+ return [], {"memory_read_enabled": True, "memory_read_count": 0}
|
|
|
+
|
|
|
+ scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
|
|
|
+ if scope is None:
|
|
|
+ return [], {
|
|
|
+ "memory_read_enabled": True,
|
|
|
+ "memory_read_count": 0,
|
|
|
+ "memory_read_reason": "scope_unavailable",
|
|
|
+ }
|
|
|
+
|
|
|
+ scope_type, scope_id = scope
|
|
|
+ try:
|
|
|
+ results = self.memory_client.search_memories(
|
|
|
+ MemorySearchRequestContract(
|
|
|
+ tenant_id=agent_run.tenant_id,
|
|
|
+ query=query,
|
|
|
+ scope_type=scope_type,
|
|
|
+ scope_id=scope_id,
|
|
|
+ owner_agent_id=agent_run.agent_id,
|
|
|
+ session_id=agent_run.session_id,
|
|
|
+ limit=self._read_int(
|
|
|
+ agent_version.memory_policy_json,
|
|
|
+ "read_top_k",
|
|
|
+ default=8,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ except MemoryClientError as exc:
|
|
|
+ return [], {
|
|
|
+ "memory_read_enabled": True,
|
|
|
+ "memory_read_count": 0,
|
|
|
+ "memory_read_error": str(exc),
|
|
|
+ }
|
|
|
+
|
|
|
+ return results, {
|
|
|
+ "memory_read_enabled": True,
|
|
|
+ "memory_read_count": len(results),
|
|
|
+ "memory_scope_type": scope_type,
|
|
|
+ "memory_scope_id": scope_id,
|
|
|
+ }
|
|
|
+
|
|
|
+ def _write_interaction_memory(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ agent_run: AgentRun,
|
|
|
+ agent_version: AgentVersion,
|
|
|
+ output_text: str,
|
|
|
+ ) -> dict[str, JSONValue]:
|
|
|
+ if self.memory_client is None:
|
|
|
+ return {"memory_write_enabled": False, "memory_write_reason": "client_missing"}
|
|
|
+ if not self._read_bool(agent_version.memory_policy_json, "write_enabled", default=True):
|
|
|
+ return {"memory_write_enabled": False, "memory_write_reason": "policy_disabled"}
|
|
|
+
|
|
|
+ scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
|
|
|
+ if scope is None:
|
|
|
+ return {"memory_write_enabled": True, "memory_write_reason": "scope_unavailable"}
|
|
|
+
|
|
|
+ scope_type, scope_id = scope
|
|
|
+ try:
|
|
|
+ memory = self.memory_client.create_memory(
|
|
|
+ MemoryCreateContract(
|
|
|
+ tenant_id=agent_run.tenant_id,
|
|
|
+ scope_type=scope_type,
|
|
|
+ scope_id=scope_id,
|
|
|
+ memory_type="conversation",
|
|
|
+ content_text=self._format_interaction_memory(
|
|
|
+ agent_run=agent_run,
|
|
|
+ output_text=output_text,
|
|
|
+ ),
|
|
|
+ content_json={
|
|
|
+ "agent_run_id": agent_run.id,
|
|
|
+ "agent_version_id": agent_version.id,
|
|
|
+ "input_text": agent_run.input_text,
|
|
|
+ "output_text": output_text,
|
|
|
+ },
|
|
|
+ metadata_json={
|
|
|
+ "source": "agent-service",
|
|
|
+ "role": agent_version.role,
|
|
|
+ "version_no": agent_version.version_no,
|
|
|
+ },
|
|
|
+ owner_agent_id=agent_run.agent_id,
|
|
|
+ session_id=agent_run.session_id,
|
|
|
+ source_ref=f"agent_run:{agent_run.id}",
|
|
|
+ importance_score=self._read_nested_int(
|
|
|
+ agent_version.memory_policy_json,
|
|
|
+ "config_json",
|
|
|
+ "write_importance_score",
|
|
|
+ default=50,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ except MemoryClientError as exc:
|
|
|
+ return {
|
|
|
+ "memory_write_enabled": True,
|
|
|
+ "memory_write_error": str(exc),
|
|
|
+ }
|
|
|
+
|
|
|
+ return {
|
|
|
+ "memory_write_enabled": True,
|
|
|
+ "memory_written_id": memory.id,
|
|
|
+ "memory_scope_type": scope_type,
|
|
|
+ "memory_scope_id": scope_id,
|
|
|
+ }
|
|
|
+
|
|
|
def _resolve_memory_scope(
    self,
    *,
    agent_run: AgentRun,
    agent_version: AgentVersion,
) -> tuple[MemoryScopeType, str] | None:
    """Pick the memory scope type and id for a run.

    The scope name comes from ``memory_scope`` in the version's
    ``memory_policy_json`` and falls back to ``"session"`` when that
    lookup yields a falsy value.

    Args:
        agent_run: Run providing the tenant, agent and session ids, plus
            ``input_json`` for the ``user_id`` / ``team_id`` lookups.
        agent_version: Version whose memory policy names the scope.

    Returns:
        ``(scope_type, scope_id)``, or ``None`` when the scope name is
        not one of tenant/agent/session/user/team or the id it needs is
        unavailable.
    """
    configured = (
        self._read_optional_string(agent_version.memory_policy_json, "memory_scope")
        or "session"
    )

    resolved: tuple[MemoryScopeType, str] | None = None
    if configured == "tenant":
        resolved = ("tenant", agent_run.tenant_id)
    elif configured == "agent":
        resolved = ("agent", agent_run.agent_id)
    elif configured == "session":
        # A session scope is only usable when the run has a session id.
        session_id = agent_run.session_id
        if session_id:
            resolved = ("session", session_id)
    elif configured == "user":
        user_id = self._read_input_json_string(agent_run=agent_run, key="user_id")
        if user_id is not None:
            resolved = ("user", user_id)
    elif configured == "team":
        team_id = self._read_input_json_string(agent_run=agent_run, key="team_id")
        if team_id is not None:
            resolved = ("team", team_id)
    return resolved
|
|
|
+
|
|
|
+ def _format_memory_context(self, memory_results: list[MemorySearchResultContract]) -> str:
|
|
|
+ lines = ["Relevant memories:"]
|
|
|
+ for index, result in enumerate(memory_results, start=1):
|
|
|
+ lines.append(f"{index}. {result.item.content_text}")
|
|
|
+ return "\n".join(lines)
|
|
|
+
|
|
|
+ def _format_interaction_memory(self, *, agent_run: AgentRun, output_text: str) -> str:
|
|
|
+ input_text = agent_run.input_text or str(agent_run.input_json or {})
|
|
|
+ return f"User input: {input_text}\nAgent output: {output_text}"
|
|
|
+
|
|
|
+ def _read_bool(self, payload: dict[str, JSONValue], key: str, *, default: bool) -> bool:
|
|
|
+ value = payload.get(key)
|
|
|
+ if isinstance(value, bool):
|
|
|
+ return value
|
|
|
+ return default
|
|
|
+
|
|
|
+ def _read_int(self, payload: dict[str, JSONValue], key: str, *, default: int) -> int:
|
|
|
+ value = payload.get(key)
|
|
|
+ if isinstance(value, int) and not isinstance(value, bool):
|
|
|
+ return value
|
|
|
+ return default
|
|
|
+
|
|
|
+ def _read_nested_int(
|
|
|
+ self,
|
|
|
+ payload: dict[str, JSONValue],
|
|
|
+ parent_key: str,
|
|
|
+ child_key: str,
|
|
|
+ *,
|
|
|
+ default: int,
|
|
|
+ ) -> int:
|
|
|
+ parent_value = payload.get(parent_key)
|
|
|
+ if not isinstance(parent_value, dict):
|
|
|
+ return default
|
|
|
+ return self._read_int(
|
|
|
+ cast(dict[str, JSONValue], parent_value),
|
|
|
+ child_key,
|
|
|
+ default=default,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _read_input_json_string(self, *, agent_run: AgentRun, key: str) -> str | None:
|
|
|
+ if agent_run.input_json is None:
|
|
|
+ return None
|
|
|
+ value = agent_run.input_json.get(key)
|
|
|
+ if isinstance(value, str) and value:
|
|
|
+ return value
|
|
|
+ return None
|
|
|
+
|
|
|
|
|
|
def build_agent_application_service(
|
|
|
*,
|
|
|
@@ -348,4 +573,8 @@ def build_agent_application_service(
|
|
|
base_url=settings.model_gateway_service_url,
|
|
|
timeout_seconds=settings.model_gateway_timeout_seconds,
|
|
|
),
|
|
|
+ memory_client=MemoryClient(
|
|
|
+ base_url=settings.memory_service_url,
|
|
|
+ timeout_seconds=settings.memory_service_timeout_seconds,
|
|
|
+ ),
|
|
|
)
|