repositories.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
from datetime import datetime, timezone

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from app.db.models import MemoryItem
  7. class MemoryItemRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. scope_type: MemoryScopeType,
  14. scope_id: str,
  15. memory_type: str,
  16. content_text: str,
  17. content_json: dict[str, JSONValue] | None,
  18. metadata_json: dict[str, JSONValue],
  19. embedding_model: str | None,
  20. embedding_json: list[float] | None,
  21. owner_agent_id: str | None,
  22. user_id: str | None,
  23. session_id: str | None,
  24. source_ref: str | None,
  25. importance_score: int,
  26. expires_time: datetime | None) -> MemoryItem:
  27. entity = MemoryItem(
  28. scope_type=scope_type,
  29. scope_id=scope_id,
  30. memory_type=memory_type,
  31. content_text=content_text,
  32. content_json=content_json,
  33. metadata_json=metadata_json,
  34. embedding_model=embedding_model,
  35. embedding_json=embedding_json,
  36. owner_agent_id=owner_agent_id,
  37. user_id=user_id,
  38. session_id=session_id,
  39. source_ref=source_ref,
  40. importance_score=importance_score,
  41. status="active",
  42. expires_time=expires_time)
  43. self.db.add(entity)
  44. self.db.commit()
  45. self.db.refresh(entity)
  46. return entity
  47. def list_by_scope(
  48. self,
  49. *,
  50. scope_type: MemoryScopeType | None = None,
  51. scope_id: str | None = None,
  52. status: MemoryStatus | None = "active",
  53. limit: int = 100) -> list[MemoryItem]:
  54. stmt = select(MemoryItem)
  55. if scope_type is not None:
  56. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  57. if scope_id is not None:
  58. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  59. if status is not None:
  60. stmt = stmt.where(MemoryItem.status == status)
  61. stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
  62. return list(self.db.scalars(stmt))
  63. def search_candidates(
  64. self,
  65. *,
  66. scope_type: MemoryScopeType | None,
  67. scope_id: str | None,
  68. owner_agent_id: str | None,
  69. user_id: str | None,
  70. session_id: str | None,
  71. limit: int) -> list[MemoryItem]:
  72. now = datetime.utcnow()
  73. stmt = (
  74. select(MemoryItem)
  75. .where(MemoryItem.status == "active")
  76. .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
  77. )
  78. if scope_type is not None:
  79. stmt = stmt.where(MemoryItem.scope_type == scope_type)
  80. if scope_id is not None:
  81. stmt = stmt.where(MemoryItem.scope_id == scope_id)
  82. if owner_agent_id is not None:
  83. stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
  84. if user_id is not None:
  85. stmt = stmt.where(MemoryItem.user_id == user_id)
  86. if session_id is not None:
  87. stmt = stmt.where(MemoryItem.session_id == session_id)
  88. stmt = stmt.order_by(
  89. MemoryItem.importance_score.desc(),
  90. MemoryItem.created_time.desc()).limit(limit)
  91. return list(self.db.scalars(stmt))
  92. def get_by_id(self, *, memory_id: str) -> MemoryItem | None:
  93. stmt = (
  94. select(MemoryItem)
  95. .where(MemoryItem.id == memory_id)
  96. )
  97. return self.db.scalar(stmt)
  98. def touch_many(self, *, memory_ids: list[str], accessed_time: datetime) -> None:
  99. if not memory_ids:
  100. return
  101. items = list(self.db.scalars(select(MemoryItem).where(MemoryItem.id.in_(memory_ids))))
  102. for item in items:
  103. item.last_accessed_time = accessed_time
  104. self.db.commit()
  105. def update_status(
  106. self,
  107. *,
  108. memory_id: str,
  109. status: MemoryStatus) -> MemoryItem | None:
  110. entity = self.get_by_id(memory_id=memory_id)
  111. if entity is None:
  112. return None
  113. entity.status = status
  114. self.db.commit()
  115. self.db.refresh(entity)
  116. return entity