| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
from datetime import datetime, timezone

from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue
from app.db.models import MemoryItem
class MemoryItemRepository:
    """Data-access helper for ``MemoryItem`` rows.

    Lookups and searches are scoped by ``tenant_id`` (``touch_many`` only when
    a ``tenant_id`` is supplied). Every mutating method commits the bound
    session before returning.
    """

    def __init__(self, db: Session) -> None:
        """Bind the repository to an open SQLAlchemy session.

        Args:
            db: Session used for every query and commit this repository issues.
        """
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        scope_type: MemoryScopeType,
        scope_id: str,
        memory_type: str,
        content_text: str,
        content_json: dict[str, JSONValue] | None,
        metadata_json: dict[str, JSONValue],
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        source_ref: str | None,
        importance_score: int,
        expires_time: datetime | None,
    ) -> MemoryItem:
        """Insert a new memory item with status ``"active"`` and commit it.

        All keyword arguments are copied verbatim onto the new ``MemoryItem``.

        Returns:
            The persisted entity, refreshed after commit so that any
            database-generated columns are loaded (which columns those are
            depends on the model definition - not visible here).
        """
        entity = MemoryItem(
            tenant_id=tenant_id,
            scope_type=scope_type,
            scope_id=scope_id,
            memory_type=memory_type,
            content_text=content_text,
            content_json=content_json,
            metadata_json=metadata_json,
            owner_agent_id=owner_agent_id,
            user_id=user_id,
            session_id=session_id,
            source_ref=source_ref,
            importance_score=importance_score,
            status="active",
            expires_time=expires_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        status: MemoryStatus | None = "active",
        limit: int = 100,
    ) -> list[MemoryItem]:
        """List a tenant's items, newest first.

        Each of ``scope_type``, ``scope_id`` and ``status`` narrows the result
        by equality when it is not ``None``; pass ``status=None`` to include
        items of every status. Unlike ``search``, expired items are NOT
        filtered out.

        Returns:
            At most ``limit`` items ordered by ``created_time`` descending.
        """
        stmt = select(MemoryItem).where(MemoryItem.tenant_id == tenant_id)
        optional_filters = (
            (MemoryItem.scope_type, scope_type),
            (MemoryItem.scope_id, scope_id),
            (MemoryItem.status, status),
        )
        for column, value in optional_filters:
            if value is not None:
                stmt = stmt.where(column == value)
        stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
        return list(self.db.scalars(stmt))

    def search(
        self,
        *,
        tenant_id: str,
        query: str,
        scope_type: MemoryScopeType | None,
        scope_id: str | None,
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        limit: int,
    ) -> list[MemoryItem]:
        """Find active, unexpired items whose ``content_text`` contains ``query``.

        ``query`` is matched as a literal substring: ``%``, ``_`` and the
        LIKE escape character inside it are escaped, so user input cannot act
        as a wildcard. An empty ``query`` matches every candidate row. Each
        remaining filter argument narrows the result by equality when it is
        not ``None``.

        Returns:
            At most ``limit`` items ordered by ``importance_score`` and then
            ``created_time``, both descending.
        """
        # Naive UTC "now" - the non-deprecated equivalent of datetime.utcnow().
        # Kept naive so the comparison with expires_time behaves as before.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        stmt = (
            select(MemoryItem)
            .where(MemoryItem.tenant_id == tenant_id)
            .where(MemoryItem.status == "active")
            .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
            # autoescape=True escapes LIKE metacharacters in the bound value,
            # so the query is treated as plain text rather than a pattern.
            .where(MemoryItem.content_text.contains(query, autoescape=True))
        )
        optional_filters = (
            (MemoryItem.scope_type, scope_type),
            (MemoryItem.scope_id, scope_id),
            (MemoryItem.owner_agent_id, owner_agent_id),
            (MemoryItem.user_id, user_id),
            (MemoryItem.session_id, session_id),
        )
        for column, value in optional_filters:
            if value is not None:
                stmt = stmt.where(column == value)
        stmt = stmt.order_by(
            MemoryItem.importance_score.desc(),
            MemoryItem.created_time.desc(),
        ).limit(limit)
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, memory_id: str) -> MemoryItem | None:
        """Return the item with ``memory_id`` owned by ``tenant_id``, or ``None``."""
        stmt = (
            select(MemoryItem)
            .where(MemoryItem.tenant_id == tenant_id)
            .where(MemoryItem.id == memory_id)
        )
        return self.db.scalar(stmt)

    def touch_many(
        self,
        *,
        memory_ids: list[str],
        accessed_time: datetime,
        tenant_id: str | None = None,
    ) -> None:
        """Set ``last_accessed_time`` on the given items and commit.

        Args:
            memory_ids: IDs of the items to update. An empty list is a no-op:
                no query is issued and nothing is committed. Unknown IDs are
                silently ignored.
            accessed_time: Value written to ``last_accessed_time``.
            tenant_id: When given, only items belonging to this tenant are
                touched, so IDs from another tenant are ignored. Defaults to
                ``None`` (no tenant filter) for backward compatibility.
        """
        if not memory_ids:
            return
        stmt = select(MemoryItem).where(MemoryItem.id.in_(memory_ids))
        if tenant_id is not None:
            stmt = stmt.where(MemoryItem.tenant_id == tenant_id)
        # Materialize before mutating so the result cursor is fully consumed.
        for item in self.db.scalars(stmt).all():
            item.last_accessed_time = accessed_time
        self.db.commit()

    def update_status(
        self,
        *,
        tenant_id: str,
        memory_id: str,
        status: MemoryStatus,
    ) -> MemoryItem | None:
        """Change an item's status and commit.

        Returns:
            The refreshed entity, or ``None`` when no item with ``memory_id``
            exists for ``tenant_id`` (nothing is committed in that case).
        """
        entity = self.get_by_id(tenant_id=tenant_id, memory_id=memory_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity
|