| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- from datetime import datetime
- from sqlalchemy import select
- from sqlalchemy.orm import Session
- from app.db.models import Message, RunRequest
- from app.db.models import Session as SessionModel
class SessionRepository:
    """Create and query ``SessionModel`` rows through an SQLAlchemy ORM session."""

    def __init__(self, db: Session) -> None:
        """Bind the repository to an ORM session.

        Args:
            db: SQLAlchemy ORM session used for every query and commit made
                by this repository.
        """
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as the deprecated ``datetime.utcnow()``
        (naive, UTC-based) so stored timestamps keep their previous form.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _commit_and_refresh(self, entity: SessionModel) -> None:
        """Commit pending changes and reload *entity* from the database.

        If the commit fails, the ORM session is rolled back before the
        exception propagates. A failed flush otherwise leaves the session
        unusable until ``rollback()`` is called.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(entity)

    def create(
        self,
        *,
        app_id: str,
        user_id: str,
        channel_type: str,
        title: str | None,
        runtime_target_type: str | None,
        runtime_target_id: str | None,
        runtime_target_config_id: str | None,
    ) -> SessionModel:
        """Insert a new session row, commit it, and return the refreshed entity.

        ``started_time`` and ``last_active_time`` are both set to the same
        naive-UTC instant, so a brand-new session has not "aged" before its
        first activity.

        Args:
            app_id: Owning application id.
            user_id: Owning user id.
            channel_type: Channel the session was opened on.
            title: Optional display title.
            runtime_target_type: Optional runtime target kind.
            runtime_target_id: Optional runtime target id.
            runtime_target_config_id: Optional runtime target config id.

        Returns:
            The persisted ``SessionModel``, refreshed after commit.

        Raises:
            Any exception raised by ``commit()``, after the session has been
            rolled back.
        """
        now = self._utcnow()
        entity = SessionModel(
            app_id=app_id,
            user_id=user_id,
            channel_type=channel_type,
            title=title,
            runtime_target_type=runtime_target_type,
            runtime_target_id=runtime_target_id,
            runtime_target_config_id=runtime_target_config_id,
            started_time=now,
            last_active_time=now,
        )
        self.db.add(entity)
        self._commit_and_refresh(entity)
        return entity

    def get_by_id(self, *, session_id: str) -> SessionModel | None:
        """Return the session whose ``id`` equals *session_id*, or ``None``."""
        stmt = select(SessionModel).where(SessionModel.id == session_id)
        return self.db.scalars(stmt).first()

    def list_by_scope(self, *, app_id: str | None = None) -> list[SessionModel]:
        """Return sessions, optionally restricted to one application.

        Args:
            app_id: When truthy, only sessions with this ``app_id`` are
                returned. ``None`` or ``""`` returns every session.

        Returns:
            Matching sessions, in no particular order (no ``ORDER BY`` is
            applied).
        """
        stmt = select(SessionModel)
        if app_id:
            stmt = stmt.where(SessionModel.app_id == app_id)
        return list(self.db.scalars(stmt))
class MessageRepository:
    """Create and query ``Message`` rows through an SQLAlchemy ORM session."""

    def __init__(self, db: Session) -> None:
        """Bind the repository to an ORM session.

        Args:
            db: SQLAlchemy ORM session used for every query and commit made
                by this repository.
        """
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as the deprecated ``datetime.utcnow()``
        (naive, UTC-based) so stored timestamps keep their previous form.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _commit_and_refresh(self, entity: Message) -> None:
        """Commit pending changes and reload *entity* from the database.

        If the commit fails, the ORM session is rolled back before the
        exception propagates. A failed flush otherwise leaves the session
        unusable until ``rollback()`` is called.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(entity)

    def create(
        self,
        *,
        session_id: str,
        turn_id: str | None,
        role: str,
        content_type: str,
        content_text: str | None,
        content_json: dict | None,
    ) -> Message:
        """Insert a message and bump its session's ``last_active_time``.

        The parent session is looked up by *session_id*. If it exists, its
        ``last_active_time`` is set to the current naive-UTC time in the same
        commit as the message insert. If it does not exist, the message is
        still added. Presumably a foreign-key constraint would reject it;
        NOTE(review): confirm against the schema.

        Args:
            session_id: Id of the session the message belongs to.
            turn_id: Optional conversational turn id.
            role: Author role of the message.
            content_type: Kind of content carried.
            content_text: Optional text payload.
            content_json: Optional structured payload.

        Returns:
            The persisted ``Message``, refreshed after commit.

        Raises:
            Any exception raised by ``commit()``, after the session has been
            rolled back.
        """
        entity = Message(
            session_id=session_id,
            turn_id=turn_id,
            role=role,
            content_type=content_type,
            content_text=content_text,
            content_json=content_json,
        )
        session = self.db.scalars(
            select(SessionModel).where(SessionModel.id == session_id)
        ).first()
        if session is not None:
            # The loaded row is already tracked by self.db, so assigning the
            # attribute is enough for it to be flushed on commit.
            session.last_active_time = self._utcnow()
        self.db.add(entity)
        self._commit_and_refresh(entity)
        return entity

    def list_by_session(self, *, session_id: str) -> list[Message]:
        """Return all messages of *session_id*, oldest ``created_time`` first."""
        stmt = (
            select(Message)
            .where(Message.session_id == session_id)
            .order_by(Message.created_time.asc())
        )
        return list(self.db.scalars(stmt))
class RunRequestRepository:
    """Create, query and update ``RunRequest`` rows through an ORM session."""

    def __init__(self, db: Session) -> None:
        """Bind the repository to an ORM session.

        Args:
            db: SQLAlchemy ORM session used for every query and commit made
                by this repository.
        """
        self.db = db

    def _commit_and_refresh(self, entity: RunRequest) -> None:
        """Commit pending changes and reload *entity* from the database.

        If the commit fails, the ORM session is rolled back before the
        exception propagates. A failed flush otherwise leaves the session
        unusable until ``rollback()`` is called.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(entity)

    def create(
        self,
        *,
        session_id: str,
        app_config_id: str,
        workflow_config_id: str,
        trigger_type: str,
        request_payload_json: dict | None,
        request_status: str,
    ) -> RunRequest:
        """Insert a run request, commit it, and return the refreshed entity.

        Args:
            session_id: Id of the owning session.
            app_config_id: Application config the run uses.
            workflow_config_id: Workflow config the run uses.
            trigger_type: What triggered the run.
            request_payload_json: Optional request payload.
            request_status: Initial status value.

        Returns:
            The persisted ``RunRequest``, refreshed after commit.

        Raises:
            Any exception raised by ``commit()``, after the session has been
            rolled back.
        """
        entity = RunRequest(
            session_id=session_id,
            app_config_id=app_config_id,
            workflow_config_id=workflow_config_id,
            trigger_type=trigger_type,
            request_payload_json=request_payload_json,
            request_status=request_status,
        )
        self.db.add(entity)
        self._commit_and_refresh(entity)
        return entity

    def list_by_session(self, *, session_id: str) -> list[RunRequest]:
        """Return all run requests of *session_id*, newest ``created_time`` first."""
        stmt = (
            select(RunRequest)
            .where(RunRequest.session_id == session_id)
            .order_by(RunRequest.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, run_request_id: str) -> RunRequest | None:
        """Return the run request whose ``id`` equals *run_request_id*, or ``None``."""
        stmt = select(RunRequest).where(RunRequest.id == run_request_id)
        return self.db.scalars(stmt).first()

    def update(
        self,
        *,
        run_request_id: str,
        request_payload_json: dict | None = None,
        request_status: str | None = None,
    ) -> RunRequest | None:
        """Partially update a run request and commit.

        Only arguments that are not ``None`` are applied. Passing ``None`` (the
        default) leaves that field unchanged, so this method cannot clear
        ``request_payload_json``.

        Args:
            run_request_id: Id of the run request to update.
            request_payload_json: New payload, or ``None`` to keep the current one.
            request_status: New status, or ``None`` to keep the current one.

        Returns:
            The updated, refreshed ``RunRequest``, or ``None`` if no row has
            that id. In that case nothing is committed.

        Raises:
            Any exception raised by ``commit()``, after the session has been
            rolled back.
        """
        entity = self.get_by_id(run_request_id=run_request_id)
        if entity is None:
            return None
        if request_payload_json is not None:
            entity.request_payload_json = request_payload_json
        if request_status is not None:
            entity.request_status = request_status
        # The entity was loaded through self.db and is already tracked;
        # attribute changes are flushed on commit without an explicit add().
        self._commit_and_refresh(entity)
        return entity
|