repositories.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import Message, RunRequest
from app.db.models import Session as SessionModel
  6. class SessionRepository:
  7. def __init__(self, db: Session) -> None:
  8. self.db = db
  9. def create(
  10. self,
  11. *,
  12. app_id: str,
  13. user_id: str,
  14. channel_type: str,
  15. title: str | None) -> SessionModel:
  16. entity = SessionModel(
  17. app_id=app_id,
  18. user_id=user_id,
  19. channel_type=channel_type,
  20. title=title,
  21. started_time=datetime.utcnow(),
  22. last_active_time=datetime.utcnow())
  23. self.db.add(entity)
  24. self.db.commit()
  25. self.db.refresh(entity)
  26. return entity
  27. def list_by_scope(self, *, app_id: str | None = None) -> list[SessionModel]:
  28. stmt = select(SessionModel)
  29. if app_id:
  30. stmt = stmt.where(SessionModel.app_id == app_id)
  31. return list(self.db.scalars(stmt))
  32. class MessageRepository:
  33. def __init__(self, db: Session) -> None:
  34. self.db = db
  35. def create(
  36. self,
  37. *,
  38. session_id: str,
  39. turn_id: str | None,
  40. role: str,
  41. content_type: str,
  42. content_text: str | None,
  43. content_json: dict | None) -> Message:
  44. entity = Message(
  45. session_id=session_id,
  46. turn_id=turn_id,
  47. role=role,
  48. content_type=content_type,
  49. content_text=content_text,
  50. content_json=content_json)
  51. self.db.add(entity)
  52. self.db.commit()
  53. self.db.refresh(entity)
  54. return entity
  55. def list_by_session(self, *, session_id: str) -> list[Message]:
  56. stmt = (
  57. select(Message)
  58. .where(Message.session_id == session_id)
  59. .order_by(Message.created_time.asc())
  60. )
  61. return list(self.db.scalars(stmt))
  62. class RunRequestRepository:
  63. def __init__(self, db: Session) -> None:
  64. self.db = db
  65. def create(
  66. self,
  67. *,
  68. session_id: str,
  69. app_version_id: str,
  70. workflow_version_id: str,
  71. trigger_type: str,
  72. request_payload_json: dict | None,
  73. request_status: str) -> RunRequest:
  74. entity = RunRequest(
  75. session_id=session_id,
  76. app_version_id=app_version_id,
  77. workflow_version_id=workflow_version_id,
  78. trigger_type=trigger_type,
  79. request_payload_json=request_payload_json,
  80. request_status=request_status)
  81. self.db.add(entity)
  82. self.db.commit()
  83. self.db.refresh(entity)
  84. return entity
  85. def list_by_session(self, *, session_id: str) -> list[RunRequest]:
  86. stmt = (
  87. select(RunRequest)
  88. .where(RunRequest.session_id == session_id)
  89. .order_by(RunRequest.created_time.desc())
  90. )
  91. return list(self.db.scalars(stmt))