services.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from __future__ import annotations
  2. import hashlib
  3. import json
  4. from datetime import datetime
  5. from typing import TYPE_CHECKING
  6. from core_domain import MemoryScopeType, MemoryStatus
  7. from sqlalchemy.orm import Session
  8. from core_shared import JSONValue, try_build_redis_client
  9. from core_shared.task_queue import TaskQueuePublisher
  10. from app.application.retrieval import (
  11. build_hash_embedding,
  12. cosine_similarity,
  13. keyword_score,
  14. rerank_score,
  15. )
  16. from app.bootstrap.settings import MemoryServiceSettings
  17. from app.db.models import MemoryItem
  18. from app.domain.repositories import MemoryItemRepository
  19. from app.schemas.memory import (
  20. MemoryCreateRequest,
  21. MemoryCreateRequestDto,
  22. MemoryListRequestDto,
  23. MemorySearchRequest,
  24. MemorySearchRequestDto,
  25. MemoryStatusUpdateRequest,
  26. MemoryUpdateRequestDto,
  27. )
  28. if TYPE_CHECKING:
  29. from redis import Redis
  30. class MemoryApplicationService:
  31. def __init__(
  32. self,
  33. *,
  34. memory_repository: MemoryItemRepository,
  35. settings: MemoryServiceSettings | None = None,
  36. redis_client: Redis | None = None,
  37. task_queue_publisher: TaskQueuePublisher | None = None) -> None:
  38. self.memory_repository = memory_repository
  39. self.settings = settings or MemoryServiceSettings()
  40. self.redis_client = redis_client
  41. self.task_queue_publisher = task_queue_publisher
  42. def create_memory(self, payload: MemoryCreateRequest) -> MemoryItem:
  43. embedding_json = build_hash_embedding(
  44. payload.content_text,
  45. dimensions=self.settings.embedding_dimensions)
  46. entity = self.memory_repository.create(
  47. scope_type=payload.scope_type,
  48. scope_id=payload.scope_id,
  49. memory_type=payload.memory_type,
  50. content_text=payload.content_text,
  51. content_json=payload.content_json,
  52. metadata_json=payload.metadata_json,
  53. embedding_model=self.settings.embedding_model,
  54. embedding_json=embedding_json,
  55. owner_agent_id=payload.owner_agent_id,
  56. user_id=payload.user_id,
  57. session_id=payload.session_id,
  58. source_ref=payload.source_ref,
  59. importance_score=payload.importance_score,
  60. expires_time=payload.expires_time)
  61. self._bump_search_cache_generation()
  62. return entity
  63. def create_memory_from_contract(self, payload: MemoryCreateRequestDto) -> MemoryItem:
  64. return self.create_memory(MemoryCreateRequest(
  65. scope_type=payload.scopeType,
  66. scope_id=payload.scopeId,
  67. memory_type=payload.memoryType,
  68. content_text=payload.contentText,
  69. content_json=payload.content,
  70. metadata_json=payload.metadata,
  71. owner_agent_id=payload.ownerAgentId,
  72. user_id=payload.userId,
  73. session_id=payload.sessionId,
  74. source_ref=payload.sourceRef,
  75. importance_score=payload.importanceScore,
  76. expires_time=payload.expiresTime))
  77. def list_memories(
  78. self,
  79. *,
  80. scope_type: MemoryScopeType | None = None,
  81. scope_id: str | None = None,
  82. status: MemoryStatus | None = "active",
  83. limit: int = 100) -> list[MemoryItem]:
  84. return self.memory_repository.list_by_scope(
  85. scope_type=scope_type,
  86. scope_id=scope_id,
  87. status=status,
  88. limit=limit)
  89. def list_memories_contract(self, payload: MemoryListRequestDto) -> tuple[list[MemoryItem], int]:
  90. return self.memory_repository.list_filtered(
  91. scope_type=payload.scopeType,
  92. scope_id=payload.scopeId,
  93. memory_type=payload.memoryType,
  94. status=payload.status,
  95. owner_agent_id=payload.ownerAgentId,
  96. user_id=payload.userId,
  97. session_id=payload.sessionId,
  98. keyword=payload.keyword,
  99. include_expired=payload.includeExpired,
  100. offset=payload.offset,
  101. limit=payload.pageSize)
  102. def search_memories(
  103. self,
  104. payload: MemorySearchRequest) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
  105. query_embedding = build_hash_embedding(
  106. payload.query,
  107. dimensions=self.settings.embedding_dimensions)
  108. candidates = self.memory_repository.search_candidates(
  109. scope_type=payload.scope_type,
  110. scope_id=payload.scope_id,
  111. owner_agent_id=payload.owner_agent_id,
  112. user_id=payload.user_id,
  113. session_id=payload.session_id,
  114. limit=max(payload.limit * 10, payload.limit))
  115. scored_items = [
  116. self._score(item=item, query=payload.query, query_embedding=query_embedding)
  117. for item in candidates
  118. ]
  119. scored_items.sort(key=lambda item: item[1], reverse=True)
  120. items = [item for item, _, _ in scored_items[: payload.limit]]
  121. now = datetime.utcnow()
  122. self.memory_repository.touch_many(memory_ids=[item.id for item in items], accessed_time=now)
  123. return scored_items[: payload.limit]
  124. def search_memories_contract(
  125. self,
  126. payload: MemorySearchRequestDto) -> list[tuple[MemoryItem, float, dict[str, float | str]]]:
  127. cached = self._read_search_cache(payload=payload)
  128. if cached is not None:
  129. self._touch_memory_access([item.id for item, _, _ in cached])
  130. return cached
  131. query_embedding = build_hash_embedding(
  132. payload.query,
  133. dimensions=self.settings.embedding_dimensions)
  134. candidates = self.memory_repository.search_candidates(
  135. scope_type=payload.scopeType,
  136. scope_id=payload.scopeId,
  137. owner_agent_id=payload.ownerAgentId,
  138. user_id=payload.userId,
  139. session_id=payload.sessionId,
  140. memory_type=payload.memoryType,
  141. limit=max(payload.limit * 10, payload.limit))
  142. scored_items = [
  143. self._score(item=item, query=payload.query, query_embedding=query_embedding)
  144. for item in candidates
  145. ]
  146. scored_items.sort(key=lambda item: item[1], reverse=True)
  147. results = scored_items[: payload.limit]
  148. self._write_search_cache(payload=payload, results=results)
  149. self._touch_memory_access([item.id for item, _, _ in results])
  150. return results
  151. def update_memory_status(
  152. self,
  153. *,
  154. memory_id: str,
  155. payload: MemoryStatusUpdateRequest) -> MemoryItem | None:
  156. entity = self.memory_repository.update_status(
  157. memory_id=memory_id,
  158. status=payload.status)
  159. if entity is not None:
  160. self._bump_search_cache_generation()
  161. return entity
  162. def get_memory(self, *, memory_id: str) -> MemoryItem | None:
  163. return self.memory_repository.get_by_id(memory_id=memory_id)
  164. def update_memory(self, payload: MemoryUpdateRequestDto) -> MemoryItem | None:
  165. embedding_json: list[float] | None = None
  166. embedding_model: str | None = None
  167. if payload.contentText is not None:
  168. embedding_json = build_hash_embedding(
  169. payload.contentText,
  170. dimensions=self.settings.embedding_dimensions)
  171. embedding_model = self.settings.embedding_model
  172. entity = self.memory_repository.update(
  173. memory_id=payload.memoryId,
  174. scope_type=payload.scopeType,
  175. scope_id=payload.scopeId,
  176. memory_type=payload.memoryType,
  177. content_text=payload.contentText,
  178. content_json=payload.content,
  179. metadata_json=payload.metadata,
  180. embedding_model=embedding_model,
  181. embedding_json=embedding_json,
  182. owner_agent_id=payload.ownerAgentId,
  183. user_id=payload.userId,
  184. session_id=payload.sessionId,
  185. source_ref=payload.sourceRef,
  186. importance_score=payload.importanceScore,
  187. expires_time=payload.expiresTime)
  188. if entity is not None:
  189. self._bump_search_cache_generation()
  190. return entity
  191. def delete_memory(self, *, memory_id: str) -> MemoryItem | None:
  192. entity = self.memory_repository.update_status(memory_id=memory_id, status="deleted")
  193. if entity is not None:
  194. self._bump_search_cache_generation()
  195. return entity
  196. def touch_memories(self, *, memory_ids: list[str], accessed_time: datetime | None = None) -> None:
  197. self.memory_repository.touch_many(
  198. memory_ids=memory_ids,
  199. accessed_time=accessed_time or datetime.utcnow())
  200. def _touch_memory_access(self, memory_ids: list[str]) -> None:
  201. if not memory_ids:
  202. return
  203. if (
  204. self.settings.async_touch_enabled
  205. and self.task_queue_publisher is not None
  206. and self.task_queue_publisher.publish_memory_touch(memory_ids=memory_ids)
  207. ):
  208. return
  209. self.touch_memories(memory_ids=memory_ids)
  210. def _read_search_cache(
  211. self,
  212. *,
  213. payload: MemorySearchRequestDto) -> list[tuple[MemoryItem, float, dict[str, float | str]]] | None:
  214. if self.redis_client is None or self.settings.search_cache_ttl_seconds <= 0:
  215. return None
  216. raw_value = self.redis_client.get(self._search_cache_key(payload=payload))
  217. if not isinstance(raw_value, (bytes, str)):
  218. return None
  219. decoded = raw_value.decode("utf-8") if isinstance(raw_value, bytes) else raw_value
  220. try:
  221. cached_items = json.loads(decoded)
  222. except json.JSONDecodeError:
  223. return None
  224. if not isinstance(cached_items, list):
  225. return None
  226. results: list[tuple[MemoryItem, float, dict[str, float | str]]] = []
  227. for cached_item in cached_items:
  228. if not isinstance(cached_item, dict):
  229. return None
  230. memory_id = cached_item.get("memoryId")
  231. score = cached_item.get("score")
  232. score_details = cached_item.get("scoreDetails")
  233. if not isinstance(memory_id, str) or not isinstance(score, (int, float)) or not isinstance(score_details, dict):
  234. return None
  235. item = self.memory_repository.get_by_id(memory_id=memory_id)
  236. if item is None or item.status != "active":
  237. return None
  238. results.append((
  239. item,
  240. float(score),
  241. {
  242. str(key): value
  243. for key, value in score_details.items()
  244. if isinstance(value, (int, float, str))
  245. }))
  246. return results
  247. def _write_search_cache(
  248. self,
  249. *,
  250. payload: MemorySearchRequestDto,
  251. results: list[tuple[MemoryItem, float, dict[str, float | str]]]) -> None:
  252. if self.redis_client is None or self.settings.search_cache_ttl_seconds <= 0:
  253. return
  254. cache_payload = [
  255. {
  256. "memoryId": item.id,
  257. "score": score,
  258. "scoreDetails": score_details,
  259. }
  260. for item, score, score_details in results
  261. ]
  262. self.redis_client.set(
  263. self._search_cache_key(payload=payload),
  264. json.dumps(cache_payload, ensure_ascii=False),
  265. ex=self.settings.search_cache_ttl_seconds)
  266. def _search_cache_key(self, *, payload: MemorySearchRequestDto) -> str:
  267. cache_payload = {
  268. **payload.model_dump(mode="json"),
  269. "generation": self._read_search_cache_generation(),
  270. "embeddingDimensions": self.settings.embedding_dimensions,
  271. "embeddingModel": self.settings.embedding_model,
  272. }
  273. digest = hashlib.sha256(
  274. json.dumps(cache_payload, sort_keys=True, ensure_ascii=False).encode("utf-8")
  275. ).hexdigest()
  276. return f"memory-search:{digest}"
  277. def _read_search_cache_generation(self) -> int:
  278. if self.redis_client is None:
  279. return 0
  280. value = self.redis_client.get("memory-search:generation")
  281. if isinstance(value, bytes):
  282. value = value.decode("utf-8")
  283. if isinstance(value, str) and value.isdigit():
  284. return int(value)
  285. return 0
  286. def _bump_search_cache_generation(self) -> None:
  287. if self.redis_client is None:
  288. return
  289. try:
  290. self.redis_client.incr("memory-search:generation")
  291. except Exception:
  292. return
  293. def _score(
  294. self,
  295. *,
  296. item: MemoryItem,
  297. query: str,
  298. query_embedding: list[float]) -> tuple[MemoryItem, float, dict[str, float | str]]:
  299. keyword = keyword_score(query, item.content_text)
  300. vector = cosine_similarity(query_embedding, item.embedding_json)
  301. score = rerank_score(
  302. keyword=keyword,
  303. vector=vector,
  304. importance_score=item.importance_score)
  305. return item, score, {
  306. "keyword_score": round(keyword, 6),
  307. "vector_score": round(vector, 6),
  308. "importance_score": float(item.importance_score),
  309. "embedding_model": item.embedding_model or self.settings.embedding_model,
  310. "rerank_mode": "hybrid-local",
  311. }
  312. def build_memory_application_service(
  313. *,
  314. db: Session,
  315. settings: MemoryServiceSettings) -> MemoryApplicationService:
  316. redis_client = try_build_redis_client(settings.redis_url)
  317. return MemoryApplicationService(
  318. memory_repository=MemoryItemRepository(db),
  319. settings=settings,
  320. redis_client=redis_client,
  321. task_queue_publisher=(
  322. TaskQueuePublisher(client=redis_client) if redis_client is not None else None
  323. ))