"""Thin persistence helpers for sessions, messages and run requests."""

from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import Message, RunRequest
from app.db.models import Session as SessionModel


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value shape as the deprecated ``datetime.utcnow()``
    (naive, UTC-based), so stored timestamps keep their existing format.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class SessionRepository:
    """Create and query ``Session`` rows through a SQLAlchemy session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        app_id: str,
        user_id: str,
        channel_type: str,
        title: str | None,
    ) -> SessionModel:
        """Insert a new session, commit, and return the refreshed entity.

        Both ``started_time`` and ``last_active_time`` are set to the same
        current UTC timestamp.
        """
        # Take a single timestamp so a brand-new session reports identical
        # start and last-active times (two separate clock reads could differ).
        now = _utcnow()
        entity = SessionModel(
            app_id=app_id,
            user_id=user_id,
            channel_type=channel_type,
            title=title,
            started_time=now,
            last_active_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        # Reload DB-populated columns on the committed instance.
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, app_id: str | None = None) -> list[SessionModel]:
        """Return sessions, optionally restricted to one ``app_id``.

        Any falsy ``app_id`` (``None`` or ``""``) applies no filter. No
        ordering is imposed on the result.
        """
        stmt = select(SessionModel)
        if app_id:
            stmt = stmt.where(SessionModel.app_id == app_id)
        return list(self.db.scalars(stmt))


class MessageRepository:
    """Create and query ``Message`` rows through a SQLAlchemy session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        session_id: str,
        turn_id: str | None,
        role: str,
        content_type: str,
        content_text: str | None,
        content_json: dict | None,
    ) -> Message:
        """Insert a new message, commit, and return the refreshed entity."""
        entity = Message(
            session_id=session_id,
            turn_id=turn_id,
            role=role,
            content_type=content_type,
            content_text=content_text,
            content_json=content_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_session(self, *, session_id: str) -> list[Message]:
        """Return all messages of a session, oldest first by ``created_time``."""
        stmt = (
            select(Message)
            .where(Message.session_id == session_id)
            .order_by(Message.created_time.asc())
        )
        return list(self.db.scalars(stmt))


class RunRequestRepository:
    """Create and query ``RunRequest`` rows through a SQLAlchemy session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        session_id: str,
        app_config_id: str,
        workflow_config_id: str,
        trigger_type: str,
        request_payload_json: dict | None,
        request_status: str,
    ) -> RunRequest:
        """Insert a new run request, commit, and return the refreshed entity."""
        entity = RunRequest(
            session_id=session_id,
            app_config_id=app_config_id,
            workflow_config_id=workflow_config_id,
            trigger_type=trigger_type,
            request_payload_json=request_payload_json,
            request_status=request_status,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_session(self, *, session_id: str) -> list[RunRequest]:
        """Return all run requests of a session, newest first by ``created_time``."""
        stmt = (
            select(RunRequest)
            .where(RunRequest.session_id == session_id)
            .order_by(RunRequest.created_time.desc())
        )
        return list(self.db.scalars(stmt))