| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- from datetime import datetime
- from sqlalchemy import select
- from sqlalchemy.orm import Session
- from app.db.models import Message, RunRequest, Session as SessionModel
class SessionRepository:
    """Create and query ``SessionModel`` rows through a SQLAlchemy session.

    ``create`` commits the unit of work held by ``db`` immediately, so any
    other pending changes on the same session are committed with it.
    """

    def __init__(self, db: Session) -> None:
        # The session is supplied by the caller and is never closed here.
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        user_id: str,
        channel_type: str,
        title: str | None,
    ) -> SessionModel:
        """Insert a new session row, commit, and return the refreshed entity.

        ``started_time`` and ``last_active_time`` are both initialised to the
        same timestamp, so a freshly created session reports no activity after
        its start.
        """
        # Take a single timestamp. Calling utcnow() twice (as before) produced
        # two slightly different values, so a brand-new session had
        # last_active_time != started_time.
        # NOTE(review): datetime.utcnow() yields a naive UTC datetime and is
        # deprecated since Python 3.12. It is kept naive here to match the
        # existing column contract — confirm before switching to aware values.
        now = datetime.utcnow()
        entity = SessionModel(
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
            channel_type=channel_type,
            title=title,
            started_time=now,
            last_active_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        # Reload the row so database-generated column values are present on
        # the returned object.
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, app_id: str | None = None) -> list[SessionModel]:
        """Return every session of ``tenant_id``, optionally limited to one app.

        A falsy ``app_id`` (``None`` or ``""``) disables the app filter. No
        ORDER BY is applied, so row order is whatever the database returns.
        """
        stmt = select(SessionModel).where(SessionModel.tenant_id == tenant_id)
        if app_id:
            stmt = stmt.where(SessionModel.app_id == app_id)
        return list(self.db.scalars(stmt))
class MessageRepository:
    """Tenant-scoped persistence for ``Message`` rows.

    Each ``create`` call commits the unit of work held by ``db`` right away.
    """

    def __init__(self, db: Session) -> None:
        # Caller-supplied session; this repository never closes it.
        self.db = db

    def _persist(self, message: Message) -> Message:
        """Add *message*, commit, reload it from the database, and return it."""
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        return message

    def create(
        self,
        *,
        tenant_id: str,
        session_id: str,
        turn_id: str | None,
        role: str,
        content_type: str,
        content_text: str | None,
        content_json: dict | None,
    ) -> Message:
        """Store one message belonging to ``session_id`` and return it committed.

        ``created_time`` is not assigned here; presumably a column default
        fills it — confirm on the ``Message`` model.
        """
        new_message = Message(
            tenant_id=tenant_id,
            session_id=session_id,
            turn_id=turn_id,
            role=role,
            content_type=content_type,
            content_text=content_text,
            content_json=content_json,
        )
        return self._persist(new_message)

    def list_by_session(self, *, tenant_id: str, session_id: str) -> list[Message]:
        """Return the messages of one session within a tenant, oldest first.

        Rows sharing the same ``created_time`` come back in database order.
        """
        query = (
            select(Message)
            .where(
                Message.tenant_id == tenant_id,
                Message.session_id == session_id,
            )
            .order_by(Message.created_time.asc())
        )
        rows = self.db.scalars(query)
        return list(rows)
class RunRequestRepository:
    """Tenant-scoped persistence for ``RunRequest`` rows.

    ``create`` commits the unit of work held by ``db`` before returning.
    """

    def __init__(self, db: Session) -> None:
        # Caller-supplied session; its lifetime is managed outside this class.
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        session_id: str,
        app_version_id: str,
        workflow_version_id: str,
        trigger_type: str,
        request_payload_json: dict | None,
        request_status: str,
    ) -> RunRequest:
        """Record a run request for ``session_id`` and return the committed row."""
        db = self.db
        run_request = RunRequest(
            tenant_id=tenant_id,
            session_id=session_id,
            app_version_id=app_version_id,
            workflow_version_id=workflow_version_id,
            trigger_type=trigger_type,
            request_payload_json=request_payload_json,
            request_status=request_status,
        )
        db.add(run_request)
        db.commit()
        # Pull the stored row back so database-populated columns are loaded.
        db.refresh(run_request)
        return run_request

    def list_by_session(self, *, tenant_id: str, session_id: str) -> list[RunRequest]:
        """Return the run requests of one session within a tenant, newest first."""
        criteria = (
            RunRequest.tenant_id == tenant_id,
            RunRequest.session_id == session_id,
        )
        query = (
            select(RunRequest)
            .where(*criteria)
            .order_by(RunRequest.created_time.desc())
        )
        result = self.db.scalars(query)
        return list(result)
|