# repositories.py - SQLAlchemy repositories for sessions, messages and run requests.

  1. from datetime import datetime
  2. from sqlalchemy import select
  3. from sqlalchemy.orm import Session
  4. from app.db.models import Message, RunRequest
  5. from app.db.models import Session as SessionModel
  6. class SessionRepository:
  7. def __init__(self, db: Session) -> None:
  8. self.db = db
  9. def create(
  10. self,
  11. *,
  12. app_id: str,
  13. user_id: str,
  14. channel_type: str,
  15. title: str | None,
  16. runtime_target_type: str | None,
  17. runtime_target_id: str | None,
  18. runtime_target_config_id: str | None) -> SessionModel:
  19. entity = SessionModel(
  20. app_id=app_id,
  21. user_id=user_id,
  22. channel_type=channel_type,
  23. title=title,
  24. runtime_target_type=runtime_target_type,
  25. runtime_target_id=runtime_target_id,
  26. runtime_target_config_id=runtime_target_config_id,
  27. started_time=datetime.utcnow(),
  28. last_active_time=datetime.utcnow())
  29. self.db.add(entity)
  30. self.db.commit()
  31. self.db.refresh(entity)
  32. return entity
  33. def get_by_id(self, *, session_id: str) -> SessionModel | None:
  34. stmt = select(SessionModel).where(SessionModel.id == session_id)
  35. return self.db.scalars(stmt).first()
  36. def list_by_scope(self, *, app_id: str | None = None) -> list[SessionModel]:
  37. stmt = select(SessionModel)
  38. if app_id:
  39. stmt = stmt.where(SessionModel.app_id == app_id)
  40. return list(self.db.scalars(stmt))
  41. class MessageRepository:
  42. def __init__(self, db: Session) -> None:
  43. self.db = db
  44. def create(
  45. self,
  46. *,
  47. session_id: str,
  48. turn_id: str | None,
  49. role: str,
  50. content_type: str,
  51. content_text: str | None,
  52. content_json: dict | None) -> Message:
  53. entity = Message(
  54. session_id=session_id,
  55. turn_id=turn_id,
  56. role=role,
  57. content_type=content_type,
  58. content_text=content_text,
  59. content_json=content_json)
  60. session = self.db.scalars(
  61. select(SessionModel).where(SessionModel.id == session_id)
  62. ).first()
  63. if session is not None:
  64. session.last_active_time = datetime.utcnow()
  65. self.db.add(session)
  66. self.db.add(entity)
  67. self.db.commit()
  68. self.db.refresh(entity)
  69. return entity
  70. def list_by_session(self, *, session_id: str) -> list[Message]:
  71. stmt = (
  72. select(Message)
  73. .where(Message.session_id == session_id)
  74. .order_by(Message.created_time.asc())
  75. )
  76. return list(self.db.scalars(stmt))
  77. class RunRequestRepository:
  78. def __init__(self, db: Session) -> None:
  79. self.db = db
  80. def create(
  81. self,
  82. *,
  83. session_id: str,
  84. app_config_id: str,
  85. workflow_config_id: str,
  86. trigger_type: str,
  87. request_payload_json: dict | None,
  88. request_status: str) -> RunRequest:
  89. entity = RunRequest(
  90. session_id=session_id,
  91. app_config_id=app_config_id,
  92. workflow_config_id=workflow_config_id,
  93. trigger_type=trigger_type,
  94. request_payload_json=request_payload_json,
  95. request_status=request_status)
  96. self.db.add(entity)
  97. self.db.commit()
  98. self.db.refresh(entity)
  99. return entity
  100. def list_by_session(self, *, session_id: str) -> list[RunRequest]:
  101. stmt = (
  102. select(RunRequest)
  103. .where(RunRequest.session_id == session_id)
  104. .order_by(RunRequest.created_time.desc())
  105. )
  106. return list(self.db.scalars(stmt))
  107. def get_by_id(self, *, run_request_id: str) -> RunRequest | None:
  108. stmt = select(RunRequest).where(RunRequest.id == run_request_id)
  109. return self.db.scalars(stmt).first()
  110. def update(
  111. self,
  112. *,
  113. run_request_id: str,
  114. request_payload_json: dict | None = None,
  115. request_status: str | None = None) -> RunRequest | None:
  116. entity = self.get_by_id(run_request_id=run_request_id)
  117. if entity is None:
  118. return None
  119. if request_payload_json is not None:
  120. entity.request_payload_json = request_payload_json
  121. if request_status is not None:
  122. entity.request_status = request_status
  123. self.db.add(entity)
  124. self.db.commit()
  125. self.db.refresh(entity)
  126. return entity