memory_client.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import httpx
  2. from core_domain import (
  3. MemoryCreateContract,
  4. MemoryItemContract,
  5. MemorySearchRequestContract,
  6. MemorySearchResultContract,
  7. )
  8. class MemoryClientError(Exception):
  9. pass
  10. class MemoryClient:
  11. def __init__(self, *, base_url: str, timeout_seconds: float = 10.0) -> None:
  12. self.base_url = base_url.rstrip("/")
  13. self.timeout_seconds = timeout_seconds
  14. def create_memory(self, payload: MemoryCreateContract) -> MemoryItemContract:
  15. try:
  16. with httpx.Client(timeout=self.timeout_seconds) as client:
  17. response = client.post(
  18. f"{self.base_url}/memories/create",
  19. json=_create_payload_to_contract(payload))
  20. response.raise_for_status()
  21. return MemoryItemContract.model_validate(_memory_dto_to_contract(_unwrap(response.json())))
  22. except httpx.HTTPError as exc:
  23. raise MemoryClientError(f"memory-service create request failed: {exc}") from exc
  24. def search_memories(
  25. self,
  26. payload: MemorySearchRequestContract) -> list[MemorySearchResultContract]:
  27. try:
  28. with httpx.Client(timeout=self.timeout_seconds) as client:
  29. response = client.post(
  30. f"{self.base_url}/memories/search/query",
  31. json=_search_payload_to_contract(payload))
  32. response.raise_for_status()
  33. return [
  34. MemorySearchResultContract.model_validate({
  35. "item": _memory_dto_to_contract(item["item"]),
  36. "score": item["score"],
  37. "score_json": item.get("scoreDetails", {}),
  38. })
  39. for item in _unwrap(response.json())
  40. ]
  41. except httpx.HTTPError as exc:
  42. raise MemoryClientError(f"memory-service search request failed: {exc}") from exc
  43. def _unwrap(payload: dict) -> object:
  44. if not payload.get("success", False):
  45. message = payload.get("error", {}).get("message", "memory-service request failed")
  46. raise MemoryClientError(str(message))
  47. return payload.get("data")
  48. def _create_payload_to_contract(payload: MemoryCreateContract) -> dict:
  49. data = payload.model_dump(mode="json")
  50. return {
  51. "scopeType": data["scope_type"],
  52. "scopeId": data["scope_id"],
  53. "memoryType": data.get("memory_type", "fact"),
  54. "contentText": data["content_text"],
  55. "content": data.get("content_json"),
  56. "metadata": data.get("metadata_json", {}),
  57. "ownerAgentId": data.get("owner_agent_id"),
  58. "userId": data.get("user_id"),
  59. "sessionId": data.get("session_id"),
  60. "sourceRef": data.get("source_ref"),
  61. "importanceScore": data.get("importance_score", 0),
  62. "expiresTime": data.get("expires_time"),
  63. }
  64. def _search_payload_to_contract(payload: MemorySearchRequestContract) -> dict:
  65. data = payload.model_dump(mode="json")
  66. return {
  67. "query": data["query"],
  68. "scopeType": data.get("scope_type"),
  69. "scopeId": data.get("scope_id"),
  70. "ownerAgentId": data.get("owner_agent_id"),
  71. "userId": data.get("user_id"),
  72. "sessionId": data.get("session_id"),
  73. "limit": data.get("limit", 8),
  74. }
  75. def _memory_dto_to_contract(item: object) -> dict:
  76. if not isinstance(item, dict):
  77. raise MemoryClientError("invalid memory-service response")
  78. return {
  79. "id": item["id"],
  80. "scope_type": item["scopeType"],
  81. "scope_id": item["scopeId"],
  82. "memory_type": item["memoryType"],
  83. "content_text": item["contentText"],
  84. "content_json": item.get("content"),
  85. "metadata_json": item.get("metadata", {}),
  86. "embedding_model": item.get("embeddingModel"),
  87. "embedding_json": item.get("embedding"),
  88. "owner_agent_id": item.get("ownerAgentId"),
  89. "user_id": item.get("userId"),
  90. "session_id": item.get("sessionId"),
  91. "source_ref": item.get("sourceRef"),
  92. "importance_score": item.get("importanceScore", 0),
  93. "status": item["status"],
  94. "last_accessed_time": item.get("lastAccessedTime"),
  95. "expires_time": item.get("expiresTime"),
  96. "created_time": item["createdTime"],
  97. }