repositories.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from datetime import datetime
  2. from sqlalchemy import or_, select
  3. from sqlalchemy.orm import Session
  4. from core_domain import MemoryScopeType, MemoryStatus
  5. from core_shared import JSONValue
  6. from app.db.models import MemoryItem
  7. class MemoryItemRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. scope_type: MemoryScopeType,
  15. scope_id: str,
  16. memory_type: str,
  17. content_text: str,
  18. content_json: dict[str, JSONValue] | None,
  19. metadata_json: dict[str, JSONValue],
  20. embedding_model: str | None,
  21. embedding_json: list[float] | None,
  22. owner_agent_id: str | None,
  23. user_id: str | None,
  24. session_id: str | None,
  25. source_ref: str | None,
  26. importance_score: int,
  27. expires_time: datetime | None,
  28. ) -> MemoryItem:
  29. entity = MemoryItem(
  30. tenant_id=tenant_id,
  31. scope_type=scope_type,
  32. scope_id=scope_id,
  33. memory_type=memory_type,
  34. content_text=content_text,
  35. content_json=content_json,
  36. metadata_json=metadata_json,
  37. embedding_model=embedding_model,
  38. embedding_json=embedding_json,
  39. owner_agent_id=owner_agent_id,
  40. user_id=user_id,
  41. session_id=session_id,
  42. source_ref=source_ref,
  43. importance_score=importance_score,
  44. status="active",
  45. expires_time=expires_time,
  46. )
  47. self.db.add(entity)
  48. self.db.commit()
  49. self.db.refresh(entity)
  50. return entity
  51. def list_by_scope(
  52. self,
  53. *,
  54. tenant_id: str,
  55. scope_type: MemoryScopeType | None = None,
  56. scope_id: str | None = None,
  57. status: MemoryStatus | None = "active",
  58. limit: int = 100,
  59. ) -> list[MemoryItem]:
  60. stmt = select(MemoryItem).where(MemoryItem.tenant_id == tenant_id)
  61. if scope_type is not None:
  62. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  63. if scope_id is not None:
  64. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  65. if status is not None:
  66. stmt = stmt.where(MemoryItem.status == status)
  67. stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
  68. return list(self.db.scalars(stmt))
  69. def search_candidates(
  70. self,
  71. *,
  72. tenant_id: str,
  73. scope_type: MemoryScopeType | None,
  74. scope_id: str | None,
  75. owner_agent_id: str | None,
  76. user_id: str | None,
  77. session_id: str | None,
  78. limit: int,
  79. ) -> list[MemoryItem]:
  80. now = datetime.utcnow()
  81. stmt = (
  82. select(MemoryItem)
  83. .where(MemoryItem.tenant_id == tenant_id)
  84. .where(MemoryItem.status == "active")
  85. .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
  86. )
  87. if scope_type is not None:
  88. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  89. if scope_id is not None:
  90. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  91. if owner_agent_id is not None:
  92. stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
  93. if user_id is not None:
  94. stmt = stmt.where(MemoryItem.user_id == user_id)
  95. if session_id is not None:
  96. stmt = stmt.where(MemoryItem.session_id == session_id)
  97. stmt = stmt.order_by(
  98. MemoryItem.importance_score.desc(),
  99. MemoryItem.created_time.desc(),
  100. ).limit(limit)
  101. return list(self.db.scalars(stmt))
  102. def get_by_id(self, *, tenant_id: str, memory_id: str) -> MemoryItem | None:
  103. stmt = (
  104. select(MemoryItem)
  105. .where(MemoryItem.tenant_id == tenant_id)
  106. .where(MemoryItem.id == memory_id)
  107. )
  108. return self.db.scalar(stmt)
  109. def touch_many(self, *, memory_ids: list[str], accessed_time: datetime) -> None:
  110. if not memory_ids:
  111. return
  112. items = list(self.db.scalars(select(MemoryItem).where(MemoryItem.id.in_(memory_ids))))
  113. for item in items:
  114. item.last_accessed_time = accessed_time
  115. self.db.commit()
  116. def update_status(
  117. self,
  118. *,
  119. tenant_id: str,
  120. memory_id: str,
  121. status: MemoryStatus,
  122. ) -> MemoryItem | None:
  123. entity = self.get_by_id(tenant_id=tenant_id, memory_id=memory_id)
  124. if entity is None:
  125. return None
  126. entity.status = status
  127. self.db.commit()
  128. self.db.refresh(entity)
  129. return entity