repositories.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. from datetime import datetime
  2. from sqlalchemy import func, select
  3. from sqlalchemy.orm import Session
  4. from core_domain import TeamRunStatus, TeamStatus, TeamVersionStatus
  5. from core_shared import JSONValue
  6. from app.db.models import TeamDefinition, TeamRun, TeamVersion
  class TeamDefinitionRepository:
      """Persistence operations for ``TeamDefinition`` rows.

      Mutating methods commit the session they were given and refresh the
      affected entity before handing it back.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def _persist(self, entity: TeamDefinition) -> TeamDefinition:
          """Commit pending changes, reload *entity* from the database, return it."""
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def create(
          self,
          *,
          code: str,
          name: str,
          description: str | None,
          team_type: str,
          owner_user_id: str | None,
          metadata_json: dict[str, JSONValue] | None,
      ) -> TeamDefinition:
          """Insert a new team definition and return it after commit."""
          new_team = TeamDefinition(
              code=code,
              name=name,
              description=description,
              team_type=team_type,
              owner_user_id=owner_user_id,
              metadata_json=metadata_json,
          )
          self.db.add(new_team)
          return self._persist(new_team)

      def list_all(self) -> list[TeamDefinition]:
          """Return every team definition, most recent ``created_time`` first."""
          newest_first = TeamDefinition.created_time.desc()
          rows = self.db.scalars(select(TeamDefinition).order_by(newest_first))
          return list(rows)

      def get_by_id(self, *, team_id: str) -> TeamDefinition | None:
          """Fetch a team definition by ``id``; ``None`` when no row matches."""
          query = select(TeamDefinition).where(TeamDefinition.id == team_id)
          return self.db.scalar(query)

      def update_status(self, *, team_id: str, status: TeamStatus) -> TeamDefinition | None:
          """Set ``status`` on the team with ``team_id`` and commit.

          Returns the refreshed entity, or ``None`` without committing when no
          team has that id.
          """
          team = self.get_by_id(team_id=team_id)
          if team is not None:
              team.status = status
              return self._persist(team)
          return None
  class TeamVersionRepository:
      """Persistence operations for ``TeamVersion`` rows.

      Version numbers are assigned per team as ``max(version_no) + 1``; the
      parent team row is locked while doing so (see ``create``).
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          team_id: str,
          status: TeamVersionStatus,
          coordination_mode: str,
          objective: str | None,
          member_refs_json: list[dict[str, JSONValue]],
          policy_json: dict[str, JSONValue]) -> TeamVersion:
          """Insert the next version of team ``team_id`` and return it after commit.

          ``published_time`` is set to ``datetime.utcnow()`` only when
          ``status == "published"``; otherwise it is ``None``.
          """
          # Lock the parent team row before reading max(version_no). Without it,
          # two concurrent creates for the same team can read the same maximum
          # and insert duplicate version numbers. The lock is held until the
          # commit below, so version numbering is serialized per team.
          self._lock_team(team_id)
          version_no = self._next_version_no(team_id)
          entity = TeamVersion(
              team_id=team_id,
              version_no=version_no,
              status=status,
              coordination_mode=coordination_mode,
              objective=objective,
              member_refs_json=member_refs_json,
              policy_json=policy_json,
              published_time=datetime.utcnow() if status == "published" else None)
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_team(self, *, team_id: str) -> list[TeamVersion]:
          """Return all versions of ``team_id``, highest ``version_no`` first."""
          stmt = (
              select(TeamVersion)
              .where(TeamVersion.team_id == team_id)
              .order_by(TeamVersion.version_no.desc())
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, team_version_id: str) -> TeamVersion | None:
          """Fetch a version by ``id``; ``None`` when no row matches."""
          stmt = (
              select(TeamVersion)
              .where(TeamVersion.id == team_version_id)
          )
          return self.db.scalar(stmt)

      def get_latest_published(self, *, team_id: str) -> TeamVersion | None:
          """Return the highest-numbered ``"published"`` version of ``team_id``, or ``None``."""
          stmt = (
              select(TeamVersion)
              .where(TeamVersion.team_id == team_id)
              .where(TeamVersion.status == "published")
              .order_by(TeamVersion.version_no.desc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def _lock_team(self, team_id: str) -> None:
          """Take a ``SELECT ... FOR UPDATE`` row lock on the parent team for this transaction."""
          # Only the id column is selected; FOR UPDATE still locks the matched row
          # on backends that support row locking. scalar() fetches the row so the
          # lock is actually acquired, not merely requested.
          stmt = (
              select(TeamDefinition.id)
              .where(TeamDefinition.id == team_id)
              .with_for_update()
          )
          self.db.scalar(stmt)

      def _next_version_no(self, team_id: str) -> int:
          """Return one past the current maximum ``version_no`` for ``team_id`` (1 if none)."""
          stmt = select(func.max(TeamVersion.version_no)).where(TeamVersion.team_id == team_id)
          current_max = self.db.scalar(stmt)
          return (current_max or 0) + 1
  class TeamRunRepository:
      """Persistence and work-queue operations for ``TeamRun`` rows.

      Runs are created as ``"queued"``. ``claim_next_queued`` moves them to
      ``"running"`` with a worker lease. ``release_expired_leases`` puts runs
      whose lease has lapsed back to ``"queued"``. ``update_status`` writes
      arbitrary statuses and stamps ``finished_time`` for ``"completed"``,
      ``"failed"`` and ``"cancelled"``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          team_id: str,
          team_version_id: str,
          session_id: str | None,
          input_text: str | None,
          input_json: dict[str, JSONValue] | None) -> TeamRun:
          """Insert a new ``"queued"`` run stamped with ``queued_time`` and return it after commit."""
          now = datetime.utcnow()
          entity = TeamRun(
              team_id=team_id,
              team_version_id=team_version_id,
              session_id=session_id,
              input_text=input_text,
              input_json=input_json,
              status="queued",
              queued_time=now)
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(
          self,
          *,
          team_id: str | None = None,
          session_id: str | None = None) -> list[TeamRun]:
          """List runs matching every non-``None`` filter, newest ``created_time`` first.

          With both filters ``None`` every run is returned.
          """
          stmt = select(TeamRun)
          if team_id is not None:
              stmt = stmt.where(TeamRun.team_id == team_id)
          if session_id is not None:
              stmt = stmt.where(TeamRun.session_id == session_id)
          stmt = stmt.order_by(TeamRun.created_time.desc())
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, team_run_id: str) -> TeamRun | None:
          """Fetch a run by ``id``; ``None`` when no row matches."""
          stmt = (
              select(TeamRun)
              .where(TeamRun.id == team_run_id)
          )
          return self.db.scalar(stmt)

      def claim_next_queued(
          self,
          *,
          worker_key: str,
          lease_expire_time: datetime) -> TeamRun | None:
          """Claim the oldest queued run for ``worker_key`` and commit.

          The candidate is selected with ``FOR UPDATE SKIP LOCKED``: rows locked
          by another transaction are skipped instead of waited on. On backends
          that support it, concurrent pollers therefore claim different runs.

          The claimed run becomes ``"running"`` with the given worker and
          lease. An existing ``started_time`` is kept; otherwise it is set to
          now. Returns ``None`` when no queued run is available.
          """
          stmt = (
              select(TeamRun)
              .where(TeamRun.status == "queued")
              # NOTE(review): ordering by created_time lets runs re-queued by
              # release_expired_leases jump ahead of newer runs; queued_time
              # would give FIFO by enqueue time. Confirm which is intended.
              .order_by(TeamRun.created_time.asc())
              .with_for_update(skip_locked=True)
              .limit(1)
          )
          entity = self.db.scalar(stmt)
          if entity is None:
              return None
          now = datetime.utcnow()
          entity.status = "running"
          entity.worker_key = worker_key
          entity.started_time = entity.started_time or now
          entity.lease_expire_time = lease_expire_time
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
          """Re-queue up to ``max_items`` running runs whose lease expired at or before ``now_time``.

          Released runs lose their worker and lease, are stamped with
          ``queued_time = now_time``, and have ``started_time`` and
          ``finished_time`` cleared. Commits only if something was released.
          Returns the number of runs released.
          """
          stmt = (
              select(TeamRun)
              .where(TeamRun.status == "running")
              .where(TeamRun.lease_expire_time.is_not(None))
              .where(TeamRun.lease_expire_time <= now_time)
              .order_by(TeamRun.lease_expire_time.asc())
              # Lock the selected rows, as claim_next_queued does. Without the
              # lock, a worker that finishes after this SELECT but before our
              # commit has its terminal status overwritten with "queued", and
              # the run executes again. Concurrent reapers could also process
              # the same rows. Locked rows (a worker mid-update) are skipped and
              # picked up on a later pass if still expired.
              .with_for_update(skip_locked=True)
              .limit(max_items)
          )
          entities = list(self.db.scalars(stmt))
          for entity in entities:
              entity.status = "queued"
              entity.worker_key = None
              entity.lease_expire_time = None
              entity.queued_time = now_time
              entity.started_time = None
              entity.finished_time = None
          if entities:
              self.db.commit()
          return len(entities)

      def update_status(
          self,
          *,
          team_run_id: str,
          status: TeamRunStatus,
          worker_key: str | None = None,
          output_text: str | None = None,
          output_json: dict[str, JSONValue] | None = None,
          error_code: str | None = None,
          error_message: str | None = None) -> TeamRun | None:
          """Overwrite a run's status and result fields, then commit.

          ``worker_key``, ``output_*`` and ``error_*`` are always assigned,
          so omitted arguments reset those columns to ``None``.

          Side effects by status:
          - ``"running"``: sets ``started_time`` if it is still unset.
          - any other status: clears the lease.
          - ``"completed"``, ``"failed"``, ``"cancelled"``: sets ``finished_time``.

          Returns the refreshed run, or ``None`` when ``team_run_id`` is unknown.
          """
          entity = self.db.get(TeamRun, team_run_id)
          if entity is None:
              return None
          now = datetime.utcnow()
          entity.status = status
          entity.worker_key = worker_key
          entity.output_text = output_text
          entity.output_json = output_json
          entity.error_code = error_code
          entity.error_message = error_message
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          if status != "running":
              entity.lease_expire_time = None
          if status in {"completed", "failed", "cancelled"}:
              entity.finished_time = now
          self.db.commit()
          self.db.refresh(entity)
          return entity