repositories.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  from datetime import datetime, timezone

  from sqlalchemy import or_, select
  from sqlalchemy.orm import Session

  from core_domain import MemoryScopeType, MemoryStatus
  from core_shared import JSONValue

  from app.db.models import MemoryItem
  class MemoryItemRepository:
      """Persistence operations for ``MemoryItem`` rows through a SQLAlchemy ``Session``.

      Read queries filter on ``tenant_id``. The mutating methods (``create``,
      ``touch_many``, ``update_status``) commit the bound session before returning.
      """

      def __init__(self, db: Session) -> None:
          """Bind the repository to ``db``; every query and commit goes through it."""
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          scope_type: MemoryScopeType,
          scope_id: str,
          memory_type: str,
          content_text: str,
          content_json: dict[str, JSONValue] | None,
          metadata_json: dict[str, JSONValue],
          owner_agent_id: str | None,
          user_id: str | None,
          session_id: str | None,
          source_ref: str | None,
          importance_score: int,
          expires_time: datetime | None,
      ) -> MemoryItem:
          """Insert a new memory item with status ``"active"`` and return it.

          The session is committed, then the entity is refreshed so its
          attributes reflect the row as stored in the database.
          """
          entity = MemoryItem(
              tenant_id=tenant_id,
              scope_type=scope_type,
              scope_id=scope_id,
              memory_type=memory_type,
              content_text=content_text,
              content_json=content_json,
              metadata_json=metadata_json,
              owner_agent_id=owner_agent_id,
              user_id=user_id,
              session_id=session_id,
              source_ref=source_ref,
              importance_score=importance_score,
              status="active",
              expires_time=expires_time,
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(
          self,
          *,
          tenant_id: str,
          scope_type: MemoryScopeType | None = None,
          scope_id: str | None = None,
          status: MemoryStatus | None = "active",
          limit: int = 100,
      ) -> list[MemoryItem]:
          """List a tenant's items, newest ``created_time`` first, up to ``limit`` rows.

          ``scope_type`` and ``scope_id`` filter only when not ``None``. ``status``
          defaults to ``"active"``; pass ``None`` to include every status. Unlike
          ``search``, this does not exclude items whose ``expires_time`` has passed.
          """
          stmt = select(MemoryItem).where(MemoryItem.tenant_id == tenant_id)
          if scope_type is not None:
              stmt = stmt.where(MemoryItem.scope_type == scope_type)
          if scope_id is not None:
              stmt = stmt.where(MemoryItem.scope_id == scope_id)
          if status is not None:
              stmt = stmt.where(MemoryItem.status == status)
          stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
          return list(self.db.scalars(stmt))

      def search(
          self,
          *,
          tenant_id: str,
          query: str,
          scope_type: MemoryScopeType | None,
          scope_id: str | None,
          owner_agent_id: str | None,
          user_id: str | None,
          session_id: str | None,
          limit: int,
      ) -> list[MemoryItem]:
          """Return active, unexpired items whose ``content_text`` contains ``query``.

          ``query`` is matched as a literal substring: ``%`` and ``_`` in it are
          escaped instead of acting as LIKE wildcards. Case sensitivity follows
          the database's LIKE semantics. An item is unexpired when its
          ``expires_time`` is NULL or later than the current UTC time. Each
          optional filter argument is applied only when it is not ``None``.
          Results are ordered by ``importance_score`` then ``created_time``
          (both descending) and capped at ``limit``.
          """
          # Naive UTC "now": the same value the deprecated datetime.utcnow()
          # returned, kept naive so the expires_time comparison is unchanged.
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          stmt = (
              select(MemoryItem)
              .where(MemoryItem.tenant_id == tenant_id)
              .where(MemoryItem.status == "active")
              .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
              # autoescape=True escapes %, _ and the escape character inside
              # `query`, so user text cannot inject LIKE wildcards.
              .where(MemoryItem.content_text.contains(query, autoescape=True))
          )
          optional_filters = (
              (MemoryItem.scope_type, scope_type),
              (MemoryItem.scope_id, scope_id),
              (MemoryItem.owner_agent_id, owner_agent_id),
              (MemoryItem.user_id, user_id),
              (MemoryItem.session_id, session_id),
          )
          for column, value in optional_filters:
              if value is not None:
                  stmt = stmt.where(column == value)
          stmt = stmt.order_by(
              MemoryItem.importance_score.desc(),
              MemoryItem.created_time.desc(),
          ).limit(limit)
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, tenant_id: str, memory_id: str) -> MemoryItem | None:
          """Fetch one item by id within ``tenant_id``.

          Returns ``None`` when no row has that id for this tenant.
          """
          stmt = (
              select(MemoryItem)
              .where(MemoryItem.tenant_id == tenant_id)
              .where(MemoryItem.id == memory_id)
          )
          return self.db.scalar(stmt)

      def touch_many(
          self,
          *,
          memory_ids: list[str],
          accessed_time: datetime,
          tenant_id: str | None = None,
      ) -> None:
          """Set ``last_accessed_time`` to ``accessed_time`` on the given items and commit.

          If ``memory_ids`` is empty, this returns without querying or
          committing. Ids that match no row are ignored. When ``tenant_id`` is
          given, only that tenant's items are updated. The default ``None``
          applies no tenant filter, which matches the previous behaviour.
          """
          if not memory_ids:
              return
          stmt = select(MemoryItem).where(MemoryItem.id.in_(memory_ids))
          if tenant_id is not None:
              stmt = stmt.where(MemoryItem.tenant_id == tenant_id)
          # Materialise the rows before mutating them, as the original did.
          for item in self.db.scalars(stmt).all():
              item.last_accessed_time = accessed_time
          self.db.commit()

      def update_status(
          self,
          *,
          tenant_id: str,
          memory_id: str,
          status: MemoryStatus,
      ) -> MemoryItem | None:
          """Set the status of one tenant-scoped item and return the refreshed entity.

          Returns ``None``, without committing, when the item is not found for
          ``tenant_id``.
          """
          entity = self.get_by_id(tenant_id=tenant_id, memory_id=memory_id)
          if entity is None:
              return None
          entity.status = status
          self.db.commit()
          self.db.refresh(entity)
          return entity