repositories.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from datetime import datetime
  2. from sqlalchemy import func, select
  3. from sqlalchemy.orm import Session
  4. from core_domain import TeamRunStatus, TeamStatus, TeamVersionStatus
  5. from core_shared import JSONValue
  6. from app.db.models import TeamDefinition, TeamRun, TeamVersion
  7. class TeamDefinitionRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. code: str,
  15. name: str,
  16. description: str | None,
  17. team_type: str,
  18. owner_user_id: str | None,
  19. metadata_json: dict[str, JSONValue] | None,
  20. ) -> TeamDefinition:
  21. entity = TeamDefinition(
  22. tenant_id=tenant_id,
  23. code=code,
  24. name=name,
  25. description=description,
  26. team_type=team_type,
  27. owner_user_id=owner_user_id,
  28. metadata_json=metadata_json,
  29. )
  30. self.db.add(entity)
  31. self.db.commit()
  32. self.db.refresh(entity)
  33. return entity
  34. def list_by_tenant(self, *, tenant_id: str) -> list[TeamDefinition]:
  35. stmt = (
  36. select(TeamDefinition)
  37. .where(TeamDefinition.tenant_id == tenant_id)
  38. .order_by(TeamDefinition.created_time.desc())
  39. )
  40. return list(self.db.scalars(stmt))
  41. def get_by_id(self, *, tenant_id: str, team_id: str) -> TeamDefinition | None:
  42. stmt = (
  43. select(TeamDefinition)
  44. .where(TeamDefinition.tenant_id == tenant_id)
  45. .where(TeamDefinition.id == team_id)
  46. )
  47. return self.db.scalar(stmt)
  48. def update_status(
  49. self,
  50. *,
  51. tenant_id: str,
  52. team_id: str,
  53. status: TeamStatus,
  54. ) -> TeamDefinition | None:
  55. entity = self.get_by_id(tenant_id=tenant_id, team_id=team_id)
  56. if entity is None:
  57. return None
  58. entity.status = status
  59. self.db.commit()
  60. self.db.refresh(entity)
  61. return entity
  62. class TeamVersionRepository:
  63. def __init__(self, db: Session) -> None:
  64. self.db = db
  65. def create(
  66. self,
  67. *,
  68. tenant_id: str,
  69. team_id: str,
  70. status: TeamVersionStatus,
  71. coordination_mode: str,
  72. objective: str | None,
  73. member_refs_json: list[dict[str, JSONValue]],
  74. policy_json: dict[str, JSONValue],
  75. ) -> TeamVersion:
  76. version_no = self._next_version_no(team_id)
  77. entity = TeamVersion(
  78. tenant_id=tenant_id,
  79. team_id=team_id,
  80. version_no=version_no,
  81. status=status,
  82. coordination_mode=coordination_mode,
  83. objective=objective,
  84. member_refs_json=member_refs_json,
  85. policy_json=policy_json,
  86. published_time=datetime.utcnow() if status == "published" else None,
  87. )
  88. self.db.add(entity)
  89. self.db.commit()
  90. self.db.refresh(entity)
  91. return entity
  92. def list_by_team(self, *, tenant_id: str, team_id: str) -> list[TeamVersion]:
  93. stmt = (
  94. select(TeamVersion)
  95. .where(TeamVersion.tenant_id == tenant_id)
  96. .where(TeamVersion.team_id == team_id)
  97. .order_by(TeamVersion.version_no.desc())
  98. )
  99. return list(self.db.scalars(stmt))
  100. def get_by_id(self, *, tenant_id: str, team_version_id: str) -> TeamVersion | None:
  101. stmt = (
  102. select(TeamVersion)
  103. .where(TeamVersion.tenant_id == tenant_id)
  104. .where(TeamVersion.id == team_version_id)
  105. )
  106. return self.db.scalar(stmt)
  107. def get_latest_published(self, *, tenant_id: str, team_id: str) -> TeamVersion | None:
  108. stmt = (
  109. select(TeamVersion)
  110. .where(TeamVersion.tenant_id == tenant_id)
  111. .where(TeamVersion.team_id == team_id)
  112. .where(TeamVersion.status == "published")
  113. .order_by(TeamVersion.version_no.desc())
  114. .limit(1)
  115. )
  116. return self.db.scalar(stmt)
  117. def _next_version_no(self, team_id: str) -> int:
  118. stmt = select(func.max(TeamVersion.version_no)).where(TeamVersion.team_id == team_id)
  119. current_max = self.db.scalar(stmt)
  120. return (current_max or 0) + 1
  121. class TeamRunRepository:
  122. def __init__(self, db: Session) -> None:
  123. self.db = db
  124. def create(
  125. self,
  126. *,
  127. tenant_id: str,
  128. team_id: str,
  129. team_version_id: str,
  130. session_id: str | None,
  131. input_text: str | None,
  132. input_json: dict[str, JSONValue] | None,
  133. ) -> TeamRun:
  134. now = datetime.utcnow()
  135. entity = TeamRun(
  136. tenant_id=tenant_id,
  137. team_id=team_id,
  138. team_version_id=team_version_id,
  139. session_id=session_id,
  140. input_text=input_text,
  141. input_json=input_json,
  142. status="queued",
  143. queued_time=now,
  144. )
  145. self.db.add(entity)
  146. self.db.commit()
  147. self.db.refresh(entity)
  148. return entity
  149. def list_by_scope(
  150. self,
  151. *,
  152. tenant_id: str,
  153. team_id: str | None = None,
  154. session_id: str | None = None,
  155. ) -> list[TeamRun]:
  156. stmt = select(TeamRun).where(TeamRun.tenant_id == tenant_id)
  157. if team_id is not None:
  158. stmt = stmt.where(TeamRun.team_id == team_id)
  159. if session_id is not None:
  160. stmt = stmt.where(TeamRun.session_id == session_id)
  161. stmt = stmt.order_by(TeamRun.created_time.desc())
  162. return list(self.db.scalars(stmt))
  163. def get_by_id(self, *, tenant_id: str, team_run_id: str) -> TeamRun | None:
  164. stmt = (
  165. select(TeamRun)
  166. .where(TeamRun.tenant_id == tenant_id)
  167. .where(TeamRun.id == team_run_id)
  168. )
  169. return self.db.scalar(stmt)
  170. def claim_next_queued(
  171. self,
  172. *,
  173. worker_key: str,
  174. lease_expire_time: datetime,
  175. ) -> TeamRun | None:
  176. stmt = (
  177. select(TeamRun)
  178. .where(TeamRun.status == "queued")
  179. .order_by(TeamRun.created_time.asc())
  180. .with_for_update(skip_locked=True)
  181. .limit(1)
  182. )
  183. entity = self.db.scalar(stmt)
  184. if entity is None:
  185. return None
  186. now = datetime.utcnow()
  187. entity.status = "running"
  188. entity.worker_key = worker_key
  189. entity.started_time = entity.started_time or now
  190. entity.lease_expire_time = lease_expire_time
  191. self.db.commit()
  192. self.db.refresh(entity)
  193. return entity
  194. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  195. stmt = (
  196. select(TeamRun)
  197. .where(TeamRun.status == "running")
  198. .where(TeamRun.lease_expire_time.is_not(None))
  199. .where(TeamRun.lease_expire_time <= now_time)
  200. .order_by(TeamRun.lease_expire_time.asc())
  201. .limit(max_items)
  202. )
  203. entities = list(self.db.scalars(stmt))
  204. for entity in entities:
  205. entity.status = "queued"
  206. entity.worker_key = None
  207. entity.lease_expire_time = None
  208. entity.queued_time = now_time
  209. entity.started_time = None
  210. entity.finished_time = None
  211. if entities:
  212. self.db.commit()
  213. return len(entities)
  214. def update_status(
  215. self,
  216. *,
  217. team_run_id: str,
  218. status: TeamRunStatus,
  219. worker_key: str | None = None,
  220. output_text: str | None = None,
  221. output_json: dict[str, JSONValue] | None = None,
  222. error_code: str | None = None,
  223. error_message: str | None = None,
  224. ) -> TeamRun | None:
  225. entity = self.db.get(TeamRun, team_run_id)
  226. if entity is None:
  227. return None
  228. now = datetime.utcnow()
  229. entity.status = status
  230. entity.worker_key = worker_key
  231. entity.output_text = output_text
  232. entity.output_json = output_json
  233. entity.error_code = error_code
  234. entity.error_message = error_message
  235. if status == "running" and entity.started_time is None:
  236. entity.started_time = now
  237. if status != "running":
  238. entity.lease_expire_time = None
  239. if status in {"completed", "failed", "cancelled"}:
  240. entity.finished_time = now
  241. self.db.commit()
  242. self.db.refresh(entity)
  243. return entity