| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
from __future__ import annotations

import hashlib
import json
import logging
from datetime import datetime, timezone
from typing import TYPE_CHECKING

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue, try_build_redis_client
from core_shared.task_queue import TaskQueuePublisher
from sqlalchemy.orm import Session

from app.application.retrieval import (
    build_hash_embedding,
    cosine_similarity,
    keyword_score,
    rerank_score,
)
from app.bootstrap.settings import MemoryServiceSettings
from app.db.models import MemoryItem
from app.domain.repositories import MemoryItemRepository
from app.schemas.memory import (
    MemoryCreateRequest,
    MemoryCreateRequestDto,
    MemoryListRequestDto,
    MemorySearchRequest,
    MemorySearchRequestDto,
    MemoryStatusUpdateRequest,
    MemoryUpdateRequestDto,
)

if TYPE_CHECKING:
    from redis import Redis


class MemoryApplicationService:
    """Coordinate memory-item persistence, hybrid search and search caching.

    All persistence goes through ``memory_repository``. When a Redis client is
    supplied, contract search results are cached under keys derived from the
    request payload, the embedding settings and a generation counter; each
    successful write bumps that counter so previously written keys are no
    longer looked up. Redis is treated as best-effort: cache failures are
    logged and the call falls back to the repository.
    """

    _logger = logging.getLogger(__name__)
    _GENERATION_KEY = "memory-search:generation"

    def __init__(
        self,
        *,
        memory_repository: MemoryItemRepository,
        settings: MemoryServiceSettings | None = None,
        redis_client: Redis | None = None,
        task_queue_publisher: TaskQueuePublisher | None = None,
    ) -> None:
        """Store the collaborators.

        Args:
            memory_repository: Repository used for every read and write.
            settings: Service settings; a default ``MemoryServiceSettings()``
                is built when omitted.
            redis_client: Optional client backing the search cache. ``None``
                disables caching.
            task_queue_publisher: Optional publisher used to record memory
                access asynchronously.
        """
        self.memory_repository = memory_repository
        self.settings = settings or MemoryServiceSettings()
        self.redis_client = redis_client
        self.task_queue_publisher = task_queue_publisher

    # ------------------------------------------------------------------
    # Create / list
    # ------------------------------------------------------------------

    def create_memory(self, payload: MemoryCreateRequest) -> MemoryItem:
        """Embed ``payload.content_text``, persist a new item and invalidate the search cache."""
        embedding_json = build_hash_embedding(
            payload.content_text,
            dimensions=self.settings.embedding_dimensions,
        )
        entity = self.memory_repository.create(
            scope_type=payload.scope_type,
            scope_id=payload.scope_id,
            memory_type=payload.memory_type,
            content_text=payload.content_text,
            content_json=payload.content_json,
            metadata_json=payload.metadata_json,
            embedding_model=self.settings.embedding_model,
            embedding_json=embedding_json,
            owner_agent_id=payload.owner_agent_id,
            user_id=payload.user_id,
            session_id=payload.session_id,
            source_ref=payload.source_ref,
            importance_score=payload.importance_score,
            expires_time=payload.expires_time,
        )
        self._bump_search_cache_generation()
        return entity

    def create_memory_from_contract(self, payload: MemoryCreateRequestDto) -> MemoryItem:
        """Translate a camelCase contract DTO into ``MemoryCreateRequest`` and create the item."""
        return self.create_memory(
            MemoryCreateRequest(
                scope_type=payload.scopeType,
                scope_id=payload.scopeId,
                memory_type=payload.memoryType,
                content_text=payload.contentText,
                content_json=payload.content,
                metadata_json=payload.metadata,
                owner_agent_id=payload.ownerAgentId,
                user_id=payload.userId,
                session_id=payload.sessionId,
                source_ref=payload.sourceRef,
                importance_score=payload.importanceScore,
                expires_time=payload.expiresTime,
            )
        )

    def list_memories(
        self,
        *,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        status: MemoryStatus | None = "active",
        limit: int = 100,
    ) -> list[MemoryItem]:
        """Return items from ``list_by_scope``; ``status`` defaults to ``"active"``."""
        return self.memory_repository.list_by_scope(
            scope_type=scope_type,
            scope_id=scope_id,
            status=status,
            limit=limit,
        )

    def list_memories_contract(self, payload: MemoryListRequestDto) -> tuple[list[MemoryItem], int]:
        """Forward the contract list filters to ``list_filtered`` and return its result.

        The return annotation is ``(items, int)``; the integer is presumably
        the total match count used for paging (confirm against the repository).
        """
        return self.memory_repository.list_filtered(
            scope_type=payload.scopeType,
            scope_id=payload.scopeId,
            memory_type=payload.memoryType,
            status=payload.status,
            owner_agent_id=payload.ownerAgentId,
            user_id=payload.userId,
            session_id=payload.sessionId,
            keyword=payload.keyword,
            include_expired=payload.includeExpired,
            offset=payload.offset,
            limit=payload.pageSize,
        )

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def search_memories(
        self,
        payload: MemorySearchRequest,
    ) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
        """Rank candidates for ``payload.query`` and record access on the returned items.

        This path never consults the search cache and always records access
        synchronously through the repository.

        Returns:
            Up to ``payload.limit`` ``(item, score, score_details)`` tuples,
            highest score first.
        """
        ranked = self._rank_candidates(
            query=payload.query,
            limit=payload.limit,
            scope_type=payload.scope_type,
            scope_id=payload.scope_id,
            owner_agent_id=payload.owner_agent_id,
            user_id=payload.user_id,
            session_id=payload.session_id,
        )
        self.touch_memories(memory_ids=[item.id for item, _, _ in ranked])
        return ranked

    def search_memories_contract(
        self,
        payload: MemorySearchRequestDto,
    ) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
        """Cached variant of the search for the camelCase contract DTO.

        A usable cache entry is returned directly; otherwise candidates are
        ranked, the result is written to the cache, and in both cases access is
        recorded for the returned items.
        """
        cached = self._read_search_cache(payload=payload)
        if cached is not None:
            self._touch_memory_access([item.id for item, _, _ in cached])
            return cached
        results = self._rank_candidates(
            query=payload.query,
            limit=payload.limit,
            scope_type=payload.scopeType,
            scope_id=payload.scopeId,
            owner_agent_id=payload.ownerAgentId,
            user_id=payload.userId,
            session_id=payload.sessionId,
            memory_type=payload.memoryType,
        )
        self._write_search_cache(payload=payload, results=results)
        self._touch_memory_access([item.id for item, _, _ in results])
        return results

    def _rank_candidates(
        self,
        *,
        query: str,
        limit: int,
        **filters: object,
    ) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
        """Fetch candidates matching ``filters``, score them and keep the top ``limit``.

        ``filters`` are forwarded verbatim to ``search_candidates`` so each
        public entry point passes exactly the keyword arguments it needs.
        """
        query_embedding = build_hash_embedding(
            query,
            dimensions=self.settings.embedding_dimensions,
        )
        candidates = self.memory_repository.search_candidates(
            **filters,
            # Over-fetch 10x so re-ranking has more to choose from; ``max``
            # keeps the requested value when ``limit`` is not positive.
            limit=max(limit * 10, limit),
        )
        scored = [
            self._score(item=item, query=query, query_embedding=query_embedding)
            for item in candidates
        ]
        scored.sort(key=lambda entry: entry[1], reverse=True)
        return scored[:limit]

    # ------------------------------------------------------------------
    # Update / delete
    # ------------------------------------------------------------------

    def update_memory_status(
        self,
        *,
        memory_id: str,
        payload: MemoryStatusUpdateRequest,
    ) -> MemoryItem | None:
        """Set the item's status; invalidate the search cache when the repository returns an item."""
        entity = self.memory_repository.update_status(
            memory_id=memory_id,
            status=payload.status,
        )
        if entity is not None:
            self._bump_search_cache_generation()
        return entity

    def get_memory(self, *, memory_id: str) -> MemoryItem | None:
        """Return the repository's ``get_by_id`` result for ``memory_id``."""
        return self.memory_repository.get_by_id(memory_id=memory_id)

    def update_memory(self, payload: MemoryUpdateRequestDto) -> MemoryItem | None:
        """Update an item, re-embedding only when new content text is supplied.

        When ``payload.contentText`` is ``None`` both ``embedding_json`` and
        ``embedding_model`` are passed to the repository as ``None``.
        """
        embedding_json: list[float] | None = None
        embedding_model: str | None = None
        if payload.contentText is not None:
            embedding_json = build_hash_embedding(
                payload.contentText,
                dimensions=self.settings.embedding_dimensions,
            )
            embedding_model = self.settings.embedding_model
        entity = self.memory_repository.update(
            memory_id=payload.memoryId,
            scope_type=payload.scopeType,
            scope_id=payload.scopeId,
            memory_type=payload.memoryType,
            content_text=payload.contentText,
            content_json=payload.content,
            metadata_json=payload.metadata,
            embedding_model=embedding_model,
            embedding_json=embedding_json,
            owner_agent_id=payload.ownerAgentId,
            user_id=payload.userId,
            session_id=payload.sessionId,
            source_ref=payload.sourceRef,
            importance_score=payload.importanceScore,
            expires_time=payload.expiresTime,
        )
        if entity is not None:
            self._bump_search_cache_generation()
        return entity

    def delete_memory(self, *, memory_id: str) -> MemoryItem | None:
        """Soft-delete: set the item's status to ``"deleted"`` and invalidate the search cache."""
        entity = self.memory_repository.update_status(memory_id=memory_id, status="deleted")
        if entity is not None:
            self._bump_search_cache_generation()
        return entity

    # ------------------------------------------------------------------
    # Access tracking
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value ``datetime.utcnow()`` did, without calling the
        API deprecated since Python 3.12.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def touch_memories(self, *, memory_ids: list[str], accessed_time: datetime | None = None) -> None:
        """Record access for ``memory_ids``; ``accessed_time`` defaults to the current naive UTC time."""
        self.memory_repository.touch_many(
            memory_ids=memory_ids,
            accessed_time=accessed_time or self._utc_now(),
        )

    def _touch_memory_access(self, memory_ids: list[str]) -> None:
        """Record access asynchronously when possible, otherwise synchronously.

        The queue is used only when ``async_touch_enabled`` is set, a publisher
        exists and ``publish_memory_touch`` returns a truthy value; in every
        other case the repository is updated directly.
        """
        if not memory_ids:
            return
        if (
            self.settings.async_touch_enabled
            and self.task_queue_publisher is not None
            and self.task_queue_publisher.publish_memory_touch(memory_ids=memory_ids)
        ):
            return
        self.touch_memories(memory_ids=memory_ids)

    # ------------------------------------------------------------------
    # Search cache
    # ------------------------------------------------------------------

    def _read_search_cache(
        self,
        *,
        payload: MemorySearchRequestDto,
    ) -> list[tuple[MemoryItem, float, dict[str, float | str]]] | None:
        """Return the cached result for ``payload`` or ``None`` on any kind of miss.

        A miss is reported when caching is disabled, Redis fails, the entry is
        absent or malformed, or any referenced item is missing or no longer
        ``"active"``. Items are re-loaded from the repository by id; only ids,
        scores and score details are stored in the cache.
        """
        if self.redis_client is None or self.settings.search_cache_ttl_seconds <= 0:
            return None
        try:
            # Key derivation reads the generation counter from Redis, so it
            # belongs inside the guard as well.
            cache_key = self._search_cache_key(payload=payload)
            raw_value = self.redis_client.get(cache_key)
        except Exception:
            # The cache is an optimisation: a Redis failure must not fail the search.
            self._logger.warning(
                "memory search cache read failed; falling back to live search",
                exc_info=True,
            )
            return None
        if not isinstance(raw_value, (bytes, str)):
            return None
        decoded = raw_value.decode("utf-8") if isinstance(raw_value, bytes) else raw_value
        try:
            cached_items = json.loads(decoded)
        except json.JSONDecodeError:
            return None
        if not isinstance(cached_items, list):
            return None
        results: list[tuple[MemoryItem, float, dict[str, float | str]]] = []
        for cached_item in cached_items:
            if not isinstance(cached_item, dict):
                return None
            memory_id = cached_item.get("memoryId")
            score = cached_item.get("score")
            score_details = cached_item.get("scoreDetails")
            if (
                not isinstance(memory_id, str)
                or not isinstance(score, (int, float))
                or not isinstance(score_details, dict)
            ):
                return None
            item = self.memory_repository.get_by_id(memory_id=memory_id)
            if item is None or item.status != "active":
                # One stale entry invalidates the whole cached result.
                return None
            results.append(
                (
                    item,
                    float(score),
                    {
                        str(key): value
                        for key, value in score_details.items()
                        if isinstance(value, (int, float, str))
                    },
                )
            )
        return results

    def _write_search_cache(
        self,
        *,
        payload: MemorySearchRequestDto,
        results: list[tuple[MemoryItem, float, dict[str, float | str]]],
    ) -> None:
        """Store ids, scores and score details of ``results`` with the configured TTL.

        Does nothing when caching is disabled; Redis failures are logged and ignored.
        """
        if self.redis_client is None or self.settings.search_cache_ttl_seconds <= 0:
            return
        cache_payload = [
            {
                "memoryId": item.id,
                "score": score,
                "scoreDetails": score_details,
            }
            for item, score, score_details in results
        ]
        # Serialise outside the guard so a non-serialisable payload still surfaces as a bug.
        serialized = json.dumps(cache_payload, ensure_ascii=False)
        try:
            self.redis_client.set(
                self._search_cache_key(payload=payload),
                serialized,
                ex=self.settings.search_cache_ttl_seconds,
            )
        except Exception:
            self._logger.warning(
                "memory search cache write failed; result not cached",
                exc_info=True,
            )

    def _search_cache_key(self, *, payload: MemorySearchRequestDto) -> str:
        """Build the cache key: SHA-256 over the payload, generation and embedding settings.

        Redis errors raised while reading the generation propagate to the caller.
        """
        cache_payload = {
            **payload.model_dump(mode="json"),
            "generation": self._read_search_cache_generation(),
            "embeddingDimensions": self.settings.embedding_dimensions,
            "embeddingModel": self.settings.embedding_model,
        }
        digest = hashlib.sha256(
            json.dumps(cache_payload, sort_keys=True, ensure_ascii=False).encode("utf-8")
        ).hexdigest()
        return f"memory-search:{digest}"

    def _read_search_cache_generation(self) -> int:
        """Return the current cache generation, or ``0`` when absent or not a decimal string.

        Redis errors are not caught here; the cache read/write paths guard them.
        """
        if self.redis_client is None:
            return 0
        value = self.redis_client.get(self._GENERATION_KEY)
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        # ``isdecimal`` (not ``isdigit``): ``isdigit`` also accepts characters
        # such as superscripts that ``int()`` rejects with ``ValueError``.
        if isinstance(value, str) and value.isdecimal():
            return int(value)
        return 0

    def _bump_search_cache_generation(self) -> None:
        """Increment the generation counter so existing cache keys stop being read.

        Best-effort: a Redis failure is logged and swallowed so the write that
        triggered the bump still succeeds. Until the TTL expires, entries
        written under the old generation may then still be served.
        """
        if self.redis_client is None:
            return
        try:
            self.redis_client.incr(self._GENERATION_KEY)
        except Exception:
            self._logger.warning(
                "memory search cache generation bump failed; cached results may be stale",
                exc_info=True,
            )

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def _score(
        self,
        *,
        item: MemoryItem,
        query: str,
        query_embedding: list[float],
    ) -> tuple[MemoryItem, float, dict[str, float | str]]:
        """Combine keyword, vector and importance signals into one rerank score.

        Returns:
            ``(item, score, details)`` where ``details`` carries the rounded
            component scores, the embedding model name (the item's own, or the
            configured one when the item has none) and the rerank mode label.
        """
        keyword = keyword_score(query, item.content_text)
        vector = cosine_similarity(query_embedding, item.embedding_json)
        score = rerank_score(
            keyword=keyword,
            vector=vector,
            importance_score=item.importance_score,
        )
        details: dict[str, float | str] = {
            "keyword_score": round(keyword, 6),
            "vector_score": round(vector, 6),
            "importance_score": float(item.importance_score),
            "embedding_model": item.embedding_model or self.settings.embedding_model,
            "rerank_mode": "hybrid-local",
        }
        return item, score, details


def build_memory_application_service(
    *,
    db: Session,
    settings: MemoryServiceSettings,
) -> MemoryApplicationService:
    """Wire a ``MemoryApplicationService`` for one database session.

    A Redis client is requested for ``settings.redis_url``. When one is
    returned it backs both the service's Redis client and a
    ``TaskQueuePublisher``; when ``None`` is returned the service is built
    with neither.

    Args:
        db: SQLAlchemy session handed to ``MemoryItemRepository``.
        settings: Service settings, also passed through to the service.

    Returns:
        The fully wired application service.
    """
    client = try_build_redis_client(settings.redis_url)
    repository = MemoryItemRepository(db)
    if client is None:
        publisher = None
    else:
        publisher = TaskQueuePublisher(client=client)
    return MemoryApplicationService(
        memory_repository=repository,
        settings=settings,
        redis_client=client,
        task_queue_publisher=publisher,
    )
|