from datetime import datetime, timezone

from core_domain import MemoryScopeType, MemoryStatus
from core_shared import JSONValue
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

from app.db.models import MemoryItem

# Escape character for LIKE/ILIKE patterns built from caller-supplied keywords.
# "/" is used (rather than a backslash) to avoid dialect-specific handling of
# backslashes inside SQL string literals.
_LIKE_ESCAPE = "/"


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same kind of value as the deprecated ``datetime.utcnow()``
    (naive, UTC wall-clock), so comparisons against ``expires_time`` behave
    exactly as before.
    NOTE(review): assumes ``expires_time`` is stored as naive UTC; confirm
    against the ``MemoryItem`` model.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _contains_pattern(keyword: str) -> str:
    """Build a ``%keyword%`` LIKE pattern that matches ``keyword`` literally.

    ``%``, ``_`` and the escape character itself are prefixed with
    ``_LIKE_ESCAPE``. The result must be used with ``escape=_LIKE_ESCAPE``.
    """
    escaped = (
        # Escape the escape character first so later escapes are not doubled.
        keyword.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )
    return f"%{escaped}%"


class MemoryItemRepository:
    """Data-access helpers for ``MemoryItem`` rows on a single ``Session``.

    Every mutating method commits the session before returning.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        scope_type: MemoryScopeType,
        scope_id: str,
        memory_type: str,
        content_text: str,
        content_json: dict[str, JSONValue] | None,
        metadata_json: dict[str, JSONValue],
        embedding_model: str | None,
        embedding_json: list[float] | None,
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        source_ref: str | None,
        importance_score: int,
        expires_time: datetime | None,
    ) -> MemoryItem:
        """Insert a new memory item with status ``"active"`` and return it.

        The session is committed and the entity refreshed, so any values
        generated by the database are loaded onto the returned instance.
        """
        entity = MemoryItem(
            scope_type=scope_type,
            scope_id=scope_id,
            memory_type=memory_type,
            content_text=content_text,
            content_json=content_json,
            metadata_json=metadata_json,
            embedding_model=embedding_model,
            embedding_json=embedding_json,
            owner_agent_id=owner_agent_id,
            user_id=user_id,
            session_id=session_id,
            source_ref=source_ref,
            importance_score=importance_score,
            status="active",
            expires_time=expires_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        status: MemoryStatus | None = "active",
        limit: int = 100,
    ) -> list[MemoryItem]:
        """Return up to ``limit`` items, newest first.

        ``scope_type``, ``scope_id`` and ``status`` filter only when not
        ``None``. Pass ``status=None`` to include every status. Expired items
        are not excluded.
        """
        stmt = select(MemoryItem)
        if scope_type is not None:
            stmt = stmt.where(MemoryItem.scope_type == scope_type)
        if scope_id is not None:
            stmt = stmt.where(MemoryItem.scope_id == scope_id)
        if status is not None:
            stmt = stmt.where(MemoryItem.status == status)
        # ``id`` breaks ties between equal timestamps so results are deterministic.
        stmt = stmt.order_by(MemoryItem.created_time.desc(), MemoryItem.id.desc()).limit(limit)
        return list(self.db.scalars(stmt))

    def list_filtered(
        self,
        *,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        memory_type: str | None = None,
        status: MemoryStatus | None = "active",
        owner_agent_id: str | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        keyword: str | None = None,
        include_expired: bool = True,
        offset: int = 0,
        limit: int = 20,
    ) -> tuple[list[MemoryItem], int]:
        """Return one page of matching items plus the total number of matches.

        Filters whose value is ``None`` are ignored.

        ``keyword`` is stripped and then matched case-insensitively as a
        literal substring against ``content_text``, ``memory_type``,
        ``scope_id`` and ``source_ref``. ``%`` and ``_`` are not wildcards,
        and a blank keyword applies no filter.

        When ``include_expired`` is false, items whose ``expires_time`` is
        not in the future are excluded. Items with no ``expires_time`` are
        always kept.

        Returns:
            ``(items, total)``. ``items`` is ordered newest first and sliced
            by ``offset``/``limit``. ``total`` counts all matches before
            pagination.
        """
        stmt = select(MemoryItem)
        if scope_type is not None:
            stmt = stmt.where(MemoryItem.scope_type == scope_type)
        if scope_id is not None:
            stmt = stmt.where(MemoryItem.scope_id == scope_id)
        if memory_type is not None:
            stmt = stmt.where(MemoryItem.memory_type == memory_type)
        if status is not None:
            stmt = stmt.where(MemoryItem.status == status)
        if owner_agent_id is not None:
            stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
        if user_id is not None:
            stmt = stmt.where(MemoryItem.user_id == user_id)
        if session_id is not None:
            stmt = stmt.where(MemoryItem.session_id == session_id)
        if not include_expired:
            stmt = stmt.where(
                or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > _utcnow())
            )

        needle = (keyword or "").strip()
        if needle:
            pattern = _contains_pattern(needle)
            stmt = stmt.where(
                or_(
                    MemoryItem.content_text.ilike(pattern, escape=_LIKE_ESCAPE),
                    MemoryItem.memory_type.ilike(pattern, escape=_LIKE_ESCAPE),
                    MemoryItem.scope_id.ilike(pattern, escape=_LIKE_ESCAPE),
                    MemoryItem.source_ref.ilike(pattern, escape=_LIKE_ESCAPE),
                )
            )

        total = self.db.scalar(select(func.count()).select_from(stmt.subquery())) or 0
        # A unique tie-breaker keeps offset pagination stable when
        # created_time values collide.
        page = (
            stmt.order_by(MemoryItem.created_time.desc(), MemoryItem.id.desc())
            .offset(offset)
            .limit(limit)
        )
        items = list(self.db.scalars(page))
        return items, total

    def search_candidates(
        self,
        *,
        scope_type: MemoryScopeType | None,
        scope_id: str | None,
        owner_agent_id: str | None,
        user_id: str | None,
        session_id: str | None,
        limit: int,
        memory_type: str | None = None,
    ) -> list[MemoryItem]:
        """Return up to ``limit`` active, unexpired items for retrieval.

        Each filter applies only when not ``None``. Results are ordered by
        ``importance_score`` (highest first), then newest first.
        """
        stmt = (
            select(MemoryItem)
            .where(MemoryItem.status == "active")
            .where(or_(MemoryItem.expires_time.is_(None), MemoryItem.expires_time > _utcnow()))
        )
        if scope_type is not None:
            stmt = stmt.where(MemoryItem.scope_type == scope_type)
        if scope_id is not None:
            stmt = stmt.where(MemoryItem.scope_id == scope_id)
        if owner_agent_id is not None:
            stmt = stmt.where(MemoryItem.owner_agent_id == owner_agent_id)
        if user_id is not None:
            stmt = stmt.where(MemoryItem.user_id == user_id)
        if session_id is not None:
            stmt = stmt.where(MemoryItem.session_id == session_id)
        if memory_type is not None:
            stmt = stmt.where(MemoryItem.memory_type == memory_type)
        stmt = stmt.order_by(
            MemoryItem.importance_score.desc(),
            MemoryItem.created_time.desc(),
            MemoryItem.id.desc(),
        ).limit(limit)
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, memory_id: str) -> MemoryItem | None:
        """Return the item with primary key ``memory_id``, or ``None`` if absent."""
        stmt = select(MemoryItem).where(MemoryItem.id == memory_id)
        return self.db.scalar(stmt)

    def touch_many(self, *, memory_ids: list[str], accessed_time: datetime) -> None:
        """Set ``last_accessed_time`` on every existing item in ``memory_ids`` and commit.

        Unknown ids are ignored. An empty list returns immediately without
        querying or committing.
        """
        if not memory_ids:
            return
        items = list(self.db.scalars(select(MemoryItem).where(MemoryItem.id.in_(memory_ids))))
        for item in items:
            item.last_accessed_time = accessed_time
        self.db.commit()

    def update_status(self, *, memory_id: str, status: MemoryStatus) -> MemoryItem | None:
        """Set the item's ``status``, commit, and return the refreshed item.

        Returns ``None`` when no item has id ``memory_id``.
        """
        entity = self.get_by_id(memory_id=memory_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def update(
        self,
        *,
        memory_id: str,
        scope_type: MemoryScopeType | None = None,
        scope_id: str | None = None,
        memory_type: str | None = None,
        content_text: str | None = None,
        content_json: dict[str, JSONValue] | None = None,
        metadata_json: dict[str, JSONValue] | None = None,
        embedding_model: str | None = None,
        embedding_json: list[float] | None = None,
        owner_agent_id: str | None = None,
        user_id: str | None = None,
        session_id: str | None = None,
        source_ref: str | None = None,
        importance_score: int | None = None,
        expires_time: datetime | None = None,
    ) -> MemoryItem | None:
        """Partially update an item, commit, and return the refreshed item.

        Only arguments that are not ``None`` are written. ``None`` means
        "leave unchanged", so this method cannot clear a nullable column
        (e.g. ``expires_time``) to NULL.

        Returns ``None`` when no item has id ``memory_id``.
        """
        entity = self.get_by_id(memory_id=memory_id)
        if entity is None:
            return None
        if scope_type is not None:
            entity.scope_type = scope_type
        if scope_id is not None:
            entity.scope_id = scope_id
        if memory_type is not None:
            entity.memory_type = memory_type
        if content_text is not None:
            entity.content_text = content_text
        if content_json is not None:
            entity.content_json = content_json
        if metadata_json is not None:
            entity.metadata_json = metadata_json
        if embedding_model is not None:
            entity.embedding_model = embedding_model
        if embedding_json is not None:
            entity.embedding_json = embedding_json
        if owner_agent_id is not None:
            entity.owner_agent_id = owner_agent_id
        if user_id is not None:
            entity.user_id = user_id
        if session_id is not None:
            entity.session_id = session_id
        if source_ref is not None:
            entity.source_ref = source_ref
        if importance_score is not None:
            entity.importance_score = importance_score
        if expires_time is not None:
            entity.expires_time = expires_time
        self.db.commit()
        self.db.refresh(entity)
        return entity