# repositories.py

  1. from datetime import datetime
  2. from sqlalchemy import select
  3. from sqlalchemy.orm import Session
  4. from app.db.models import Message, RunRequest, Session as SessionModel
  5. class SessionRepository:
  6. def __init__(self, db: Session) -> None:
  7. self.db = db
  8. def create(
  9. self,
  10. *,
  11. tenant_id: str,
  12. app_id: str,
  13. user_id: str,
  14. channel_type: str,
  15. title: str | None,
  16. ) -> SessionModel:
  17. entity = SessionModel(
  18. tenant_id=tenant_id,
  19. app_id=app_id,
  20. user_id=user_id,
  21. channel_type=channel_type,
  22. title=title,
  23. started_time=datetime.utcnow(),
  24. last_active_time=datetime.utcnow(),
  25. )
  26. self.db.add(entity)
  27. self.db.commit()
  28. self.db.refresh(entity)
  29. return entity
  30. def list_by_scope(self, *, tenant_id: str, app_id: str | None = None) -> list[SessionModel]:
  31. stmt = select(SessionModel).where(SessionModel.tenant_id == tenant_id)
  32. if app_id:
  33. stmt = stmt.where(SessionModel.app_id == app_id)
  34. return list(self.db.scalars(stmt))
  35. class MessageRepository:
  36. def __init__(self, db: Session) -> None:
  37. self.db = db
  38. def create(
  39. self,
  40. *,
  41. tenant_id: str,
  42. session_id: str,
  43. turn_id: str | None,
  44. role: str,
  45. content_type: str,
  46. content_text: str | None,
  47. content_json: dict | None,
  48. ) -> Message:
  49. entity = Message(
  50. tenant_id=tenant_id,
  51. session_id=session_id,
  52. turn_id=turn_id,
  53. role=role,
  54. content_type=content_type,
  55. content_text=content_text,
  56. content_json=content_json,
  57. )
  58. self.db.add(entity)
  59. self.db.commit()
  60. self.db.refresh(entity)
  61. return entity
  62. def list_by_session(self, *, tenant_id: str, session_id: str) -> list[Message]:
  63. stmt = (
  64. select(Message)
  65. .where(Message.tenant_id == tenant_id)
  66. .where(Message.session_id == session_id)
  67. .order_by(Message.created_time.asc())
  68. )
  69. return list(self.db.scalars(stmt))
  70. class RunRequestRepository:
  71. def __init__(self, db: Session) -> None:
  72. self.db = db
  73. def create(
  74. self,
  75. *,
  76. tenant_id: str,
  77. session_id: str,
  78. app_version_id: str,
  79. workflow_version_id: str,
  80. trigger_type: str,
  81. request_payload_json: dict | None,
  82. request_status: str,
  83. ) -> RunRequest:
  84. entity = RunRequest(
  85. tenant_id=tenant_id,
  86. session_id=session_id,
  87. app_version_id=app_version_id,
  88. workflow_version_id=workflow_version_id,
  89. trigger_type=trigger_type,
  90. request_payload_json=request_payload_json,
  91. request_status=request_status,
  92. )
  93. self.db.add(entity)
  94. self.db.commit()
  95. self.db.refresh(entity)
  96. return entity
  97. def list_by_session(self, *, tenant_id: str, session_id: str) -> list[RunRequest]:
  98. stmt = (
  99. select(RunRequest)
  100. .where(RunRequest.tenant_id == tenant_id)
  101. .where(RunRequest.session_id == session_id)
  102. .order_by(RunRequest.created_time.desc())
  103. )
  104. return list(self.db.scalars(stmt))