"""Repositories for chat sessions, messages and run requests.

Each repository wraps a SQLAlchemy ``Session`` and commits on every
``create`` call, so callers receive entities that are already persisted
and refreshed from the database.
"""

from datetime import datetime, timezone
from typing import TypeVar

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import Message, RunRequest, Session as SessionModel

_EntityT = TypeVar("_EntityT")


def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated since Python 3.12. This helper
    produces the same kind of value (UTC, ``tzinfo`` stripped) so the
    timestamps handed to the models keep their previous shape.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _persist(db: Session, entity: _EntityT) -> _EntityT:
    """Add *entity* to *db*, commit, refresh it and return it.

    If the commit raises, the session is rolled back before the exception
    is re-raised, so the session stays usable for the caller.
    """
    db.add(entity)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session in a state that rejects
        # further work until it is rolled back; do that here, then let the
        # original error propagate unchanged.
        db.rollback()
        raise
    db.refresh(entity)
    return entity


class SessionRepository:
    """Create and query ``SessionModel`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        user_id: str,
        channel_type: str,
        title: str | None,
    ) -> SessionModel:
        """Insert a new session and return the committed, refreshed entity.

        ``started_time`` and ``last_active_time`` are both set to the same
        naive UTC timestamp taken once at creation.
        """
        # Take the clock reading once so both columns are exactly equal;
        # two separate calls could differ by a few microseconds.
        now = _utc_now_naive()
        entity = SessionModel(
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
            channel_type=channel_type,
            title=title,
            started_time=now,
            last_active_time=now,
        )
        return _persist(self.db, entity)

    def list_by_scope(
        self, *, tenant_id: str, app_id: str | None = None
    ) -> list[SessionModel]:
        """Return the sessions of a tenant, optionally narrowed to one app.

        The ``app_id`` filter is applied only when ``app_id`` is truthy, so
        both ``None`` and ``""`` mean "all apps of the tenant". No ordering
        is applied to the query.
        """
        stmt = select(SessionModel).where(SessionModel.tenant_id == tenant_id)
        if app_id:
            stmt = stmt.where(SessionModel.app_id == app_id)
        return list(self.db.scalars(stmt))


class MessageRepository:
    """Create and query ``Message`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        session_id: str,
        turn_id: str | None,
        role: str,
        content_type: str,
        content_text: str | None,
        content_json: dict | None,
    ) -> Message:
        """Insert a new message and return the committed, refreshed entity."""
        entity = Message(
            tenant_id=tenant_id,
            session_id=session_id,
            turn_id=turn_id,
            role=role,
            content_type=content_type,
            content_text=content_text,
            content_json=content_json,
        )
        return _persist(self.db, entity)

    def list_by_session(self, *, tenant_id: str, session_id: str) -> list[Message]:
        """Return a session's messages ordered by ``created_time`` ascending."""
        stmt = (
            select(Message)
            .where(Message.tenant_id == tenant_id)
            .where(Message.session_id == session_id)
            .order_by(Message.created_time.asc())
        )
        return list(self.db.scalars(stmt))


class RunRequestRepository:
    """Create and query ``RunRequest`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        session_id: str,
        app_version_id: str,
        workflow_version_id: str,
        trigger_type: str,
        request_payload_json: dict | None,
        request_status: str,
    ) -> RunRequest:
        """Insert a new run request and return the committed, refreshed entity."""
        entity = RunRequest(
            tenant_id=tenant_id,
            session_id=session_id,
            app_version_id=app_version_id,
            workflow_version_id=workflow_version_id,
            trigger_type=trigger_type,
            request_payload_json=request_payload_json,
            request_status=request_status,
        )
        return _persist(self.db, entity)

    def list_by_session(
        self, *, tenant_id: str, session_id: str
    ) -> list[RunRequest]:
        """Return a session's run requests ordered by ``created_time`` descending."""
        stmt = (
            select(RunRequest)
            .where(RunRequest.tenant_id == tenant_id)
            .where(RunRequest.session_id == session_id)
            .order_by(RunRequest.created_time.desc())
        )
        return list(self.db.scalars(stmt))