"""Repository for creating, querying and updating ``MemoryItem`` rows."""

from datetime import datetime, timezone

from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue

from app.db.models import MemoryItem


class MemoryItemRepository:
    """Data-access helper for ``MemoryItem`` rows on an injected session.

    Every write method commits on the injected session itself. If a commit
    fails, the session is rolled back and the original exception re-raised.
    """

    def __init__(self, db: Session) -> None:
        """Store the SQLAlchemy session used by every query in this repository."""
        self.db = db

    def _commit(self) -> None:
        """Commit the session, rolling back before re-raising on failure."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session's transaction unusable
            # until it is rolled back; do that here so the caller's session
            # can be reused after it handles the error.
            self.db.rollback()
            raise

    def create(
        self,
        *,
        tenant_id: str,
        scope_type: MemoryScopeType,
        scope_id: str,
        memory_type: str,
        content_text: str,
        content_json: dict[str, JSONValue] | None,
        metadata_json: dict[str, JSONValue],
        embedding_model: str | None,
        embedding_json: list[float] | None,
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        source_ref: str | None,
        importance_score: int,
        expires_time: datetime | None,
    ) -> MemoryItem:
        """Insert a new memory item with status ``"active"`` and return it.

        The row is committed and then refreshed from the database before it
        is returned.

        Raises:
            Exception: whatever the session raises on commit; the session is
                rolled back before the exception propagates.
        """
        entity = MemoryItem(
            tenant_id=tenant_id,
            scope_type=scope_type,
            scope_id=scope_id,
            memory_type=memory_type,
            content_text=content_text,
            content_json=content_json,
            metadata_json=metadata_json,
            embedding_model=embedding_model,
            embedding_json=embedding_json,
            owner_agent_id=owner_agent_id,
            user_id=user_id,
            session_id=session_id,
            source_ref=source_ref,
            importance_score=importance_score,
            status="active",
            expires_time=expires_time,
        )
        self.db.add(entity)
        self._commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        status: MemoryStatus | None = "active",
        limit: int = 100,
    ) -> list[MemoryItem]:
        """List a tenant's memory items, newest first.

        ``scope_type``, ``scope_id`` and ``status`` are each applied as an
        equality filter only when not ``None``. At most ``limit`` rows are
        returned, ordered by ``created_time`` descending.
        """
        stmt = select(MemoryItem).where(MemoryItem.tenant_id == tenant_id)
        if scope_type is not None:
            stmt = stmt.where(MemoryItem.scope_type == scope_type)
        if scope_id is not None:
            stmt = stmt.where(MemoryItem.scope_id == scope_id)
        if status is not None:
            stmt = stmt.where(MemoryItem.status == status)
        stmt = stmt.order_by(MemoryItem.created_time.desc()).limit(limit)
        return list(self.db.scalars(stmt))

    def search_candidates(
        self,
        *,
        tenant_id: str,
        scope_type: MemoryScopeType | None,
        scope_id: str | None,
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        limit: int,
    ) -> list[MemoryItem]:
        """Return a tenant's active, unexpired memory items.

        An item qualifies when its status is ``"active"`` and its
        ``expires_time`` is either NULL or later than the current UTC time.
        Each optional argument adds an equality filter when not ``None``.
        Results are ordered by ``importance_score`` then ``created_time``,
        both descending, and capped at ``limit`` rows.
        """
        # ``datetime.utcnow()`` is deprecated; this yields the same naive UTC
        # value, so the comparison against ``expires_time`` is unchanged.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        stmt = (
            select(MemoryItem)
            .where(MemoryItem.tenant_id == tenant_id)
            .where(MemoryItem.status == "active")
            .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > now))
        )
        if scope_type is not None:
            stmt = stmt.where(MemoryItem.scope_type == scope_type)
        if scope_id is not None:
            stmt = stmt.where(MemoryItem.scope_id == scope_id)
        if owner_agent_id is not None:
            stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
        if user_id is not None:
            stmt = stmt.where(MemoryItem.user_id == user_id)
        if session_id is not None:
            stmt = stmt.where(MemoryItem.session_id == session_id)
        stmt = stmt.order_by(
            MemoryItem.importance_score.desc(),
            MemoryItem.created_time.desc(),
        ).limit(limit)
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, memory_id: str) -> MemoryItem | None:
        """Return the item with ``memory_id`` for ``tenant_id``, or ``None``."""
        stmt = (
            select(MemoryItem)
            .where(MemoryItem.tenant_id == tenant_id)
            .where(MemoryItem.id == memory_id)
        )
        return self.db.scalar(stmt)

    def touch_many(
        self,
        *,
        memory_ids: list[str],
        accessed_time: datetime,
        tenant_id: str | None = None,
    ) -> None:
        """Set ``last_accessed_time`` on every item whose id is in ``memory_ids``.

        Args:
            memory_ids: ids of the items to touch; an empty list is a no-op
                (no query, no commit).
            accessed_time: value written to ``last_accessed_time``.
            tenant_id: when given, only items belonging to this tenant are
                touched. Defaults to ``None`` (no tenant filter), which is the
                historical behaviour of this method.

        Raises:
            Exception: whatever the session raises on commit; the session is
                rolled back before the exception propagates.
        """
        if not memory_ids:
            return
        stmt = select(MemoryItem).where(MemoryItem.id.in_(memory_ids))
        if tenant_id is not None:
            # Matches the tenant scoping applied by every other query here.
            stmt = stmt.where(MemoryItem.tenant_id == tenant_id)
        for item in self.db.scalars(stmt):
            item.last_accessed_time = accessed_time
        self._commit()

    def update_status(
        self,
        *,
        tenant_id: str,
        memory_id: str,
        status: MemoryStatus,
    ) -> MemoryItem | None:
        """Set the status of one item and return the refreshed row.

        Returns ``None`` without committing when no item with ``memory_id``
        exists for ``tenant_id``.

        Raises:
            Exception: whatever the session raises on commit; the session is
                rolled back before the exception propagates.
        """
        entity = self.get_by_id(tenant_id=tenant_id, memory_id=memory_id)
        if entity is None:
            return None
        entity.status = status
        self._commit()
        self.db.refresh(entity)
        return entity