# repositories.py (8.3 KB)
from datetime import datetime, timezone

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

from app.db.models import MemoryItem
  7. class MemoryItemRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. scope_type: MemoryScopeType,
  14. scope_id: str,
  15. memory_type: str,
  16. content_text: str,
  17. content_json: dict[str, JSONValue] | None,
  18. metadata_json: dict[str, JSONValue],
  19. embedding_model: str | None,
  20. embedding_json: list[float] | None,
  21. owner_agent_id: str | None,
  22. user_id: str | None,
  23. session_id: str | None,
  24. source_ref: str | None,
  25. importance_score: int,
  26. expires_time: datetime | None) -> MemoryItem:
  27. entity = MemoryItem(
  28. scope_type=scope_type,
  29. scope_id=scope_id,
  30. memory_type=memory_type,
  31. content_text=content_text,
  32. content_json=content_json,
  33. metadata_json=metadata_json,
  34. embedding_model=embedding_model,
  35. embedding_json=embedding_json,
  36. owner_agent_id=owner_agent_id,
  37. user_id=user_id,
  38. session_id=session_id,
  39. source_ref=source_ref,
  40. importance_score=importance_score,
  41. status="active",
  42. expires_time=expires_time)
  43. self.db.add(entity)
  44. self.db.commit()
  45. self.db.refresh(entity)
  46. return entity
  47. def list_by_scope(
  48. self,
  49. *,
  50. scope_type: MemoryScopeType | None = None,
  51. scope_id: str | None = None,
  52. status: MemoryStatus | None = "active",
  53. limit: int = 100) -> list[MemoryItem]:
  54. stmt = select(MemoryItem)
  55. if scope_type is not None:
  56. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  57. if scope_id is not None:
  58. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  59. if status is not None:
  60. stmt = stmt.where(MemoryItem.status == status)
  61. stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
  62. return list(self.db.scalars(stmt))
  63. def list_filtered(
  64. self,
  65. *,
  66. scope_type: MemoryScopeType | None = None,
  67. scope_id: str | None = None,
  68. memory_type: str | None = None,
  69. status: MemoryStatus | None = "active",
  70. owner_agent_id: str | None = None,
  71. user_id: str | None = None,
  72. session_id: str | None = None,
  73. keyword: str | None = None,
  74. include_expired: bool = True,
  75. offset: int = 0,
  76. limit: int = 20) -> tuple[list[MemoryItem], int]:
  77. stmt = select(MemoryItem)
  78. now = datetime.utcnow()
  79. if scope_type is not None:
  80. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  81. if scope_id is not None:
  82. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  83. if memory_type is not None:
  84. stmt = stmt.where(MemoryItem.memory_type == memory_type)
  85. if status is not None:
  86. stmt = stmt.where(MemoryItem.status == status)
  87. if owner_agent_id is not None:
  88. stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
  89. if user_id is not None:
  90. stmt = stmt.where(MemoryItem.user_id == user_id)
  91. if session_id is not None:
  92. stmt = stmt.where(MemoryItem.session_id == session_id)
  93. if not include_expired:
  94. stmt = stmt.where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
  95. if keyword:
  96. pattern = f"%{keyword.strip()}%"
  97. stmt = stmt.where(or_(
  98. MemoryItem.content_text.ilike(pattern),
  99. MemoryItem.memory_type.ilike(pattern),
  100. MemoryItem.scope_id.ilike(pattern),
  101. MemoryItem.source_ref.ilike(pattern)))
  102. total = self.db.scalar(select(func.count()).select_from(stmt.subquery())) or 0
  103. items = list(self.db.scalars(
  104. stmt.order_by(MemoryItem.created_time.desc()).offset(offset).limit(limit)))
  105. return items, total
  106. def search_candidates(
  107. self,
  108. *,
  109. scope_type: MemoryScopeType | None,
  110. scope_id: str | None,
  111. owner_agent_id: str | None,
  112. user_id: str | None,
  113. session_id: str | None,
  114. limit: int,
  115. memory_type: str | None = None) -> list[MemoryItem]:
  116. now = datetime.utcnow()
  117. stmt = (
  118. select(MemoryItem)
  119. .where(MemoryItem.status == "active")
  120. .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
  121. )
  122. if scope_type is not None:
  123. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  124. if scope_id is not None:
  125. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  126. if owner_agent_id is not None:
  127. stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
  128. if user_id is not None:
  129. stmt = stmt.where(MemoryItem.user_id == user_id)
  130. if session_id is not None:
  131. stmt = stmt.where(MemoryItem.session_id == session_id)
  132. if memory_type is not None:
  133. stmt = stmt.where(MemoryItem.memory_type == memory_type)
  134. stmt = stmt.order_by(
  135. MemoryItem.importance_score.desc(),
  136. MemoryItem.created_time.desc()).limit(limit)
  137. return list(self.db.scalars(stmt))
  138. def get_by_id(self, *, memory_id: str) -> MemoryItem | None:
  139. stmt = (
  140. select(MemoryItem)
  141. .where(MemoryItem.id == memory_id)
  142. )
  143. return self.db.scalar(stmt)
  144. def touch_many(self, *, memory_ids: list[str], accessed_time: datetime) -> None:
  145. if not memory_ids:
  146. return
  147. items = list(self.db.scalars(select(MemoryItem).where(MemoryItem.id.in_(memory_ids))))
  148. for item in items:
  149. item.last_accessed_time = accessed_time
  150. self.db.commit()
  151. def update_status(
  152. self,
  153. *,
  154. memory_id: str,
  155. status: MemoryStatus) -> MemoryItem | None:
  156. entity = self.get_by_id(memory_id=memory_id)
  157. if entity is None:
  158. return None
  159. entity.status = status
  160. self.db.commit()
  161. self.db.refresh(entity)
  162. return entity
  163. def update(
  164. self,
  165. *,
  166. memory_id: str,
  167. scope_type: MemoryScopeType | None = None,
  168. scope_id: str | None = None,
  169. memory_type: str | None = None,
  170. content_text: str | None = None,
  171. content_json: dict[str, JSONValue] | None = None,
  172. metadata_json: dict[str, JSONValue] | None = None,
  173. embedding_model: str | None = None,
  174. embedding_json: list[float] | None = None,
  175. owner_agent_id: str | None = None,
  176. user_id: str | None = None,
  177. session_id: str | None = None,
  178. source_ref: str | None = None,
  179. importance_score: int | None = None,
  180. expires_time: datetime | None = None) -> MemoryItem | None:
  181. entity = self.get_by_id(memory_id=memory_id)
  182. if entity is None:
  183. return None
  184. if scope_type is not None:
  185. entity.scope_type = scope_type
  186. if scope_id is not None:
  187. entity.scope_id = scope_id
  188. if memory_type is not None:
  189. entity.memory_type = memory_type
  190. if content_text is not None:
  191. entity.content_text = content_text
  192. if content_json is not None:
  193. entity.content_json = content_json
  194. if metadata_json is not None:
  195. entity.metadata_json = metadata_json
  196. if embedding_model is not None:
  197. entity.embedding_model = embedding_model
  198. if embedding_json is not None:
  199. entity.embedding_json = embedding_json
  200. if owner_agent_id is not None:
  201. entity.owner_agent_id = owner_agent_id
  202. if user_id is not None:
  203. entity.user_id = user_id
  204. if session_id is not None:
  205. entity.session_id = session_id
  206. if source_ref is not None:
  207. entity.source_ref = source_ref
  208. if importance_score is not None:
  209. entity.importance_score = importance_score
  210. if expires_time is not None:
  211. entity.expires_time = expires_time
  212. self.db.commit()
  213. self.db.refresh(entity)
  214. return entity