repositories.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import SkillInstallStatus, SkillRunStatus, SkillStatus, SkillVersionStatus
from core_shared import JSONValue

from app.db.models import SkillDefinition, SkillInstallation, SkillRun, SkillVersion
  7. class SkillDefinitionRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. code: str,
  15. name: str,
  16. skill_type: str,
  17. description: str | None,
  18. owner_user_id: str | None,
  19. metadata_json: dict[str, JSONValue] | None,
  20. ) -> SkillDefinition:
  21. entity = SkillDefinition(
  22. tenant_id=tenant_id,
  23. code=code,
  24. name=name,
  25. skill_type=skill_type,
  26. description=description,
  27. owner_user_id=owner_user_id,
  28. metadata_json=metadata_json,
  29. )
  30. self.db.add(entity)
  31. self.db.commit()
  32. self.db.refresh(entity)
  33. return entity
  34. def list_by_tenant(self, *, tenant_id: str) -> list[SkillDefinition]:
  35. stmt = (
  36. select(SkillDefinition)
  37. .where(SkillDefinition.tenant_id == tenant_id)
  38. .order_by(SkillDefinition.created_time.desc())
  39. )
  40. return list(self.db.scalars(stmt))
  41. def get_by_id(self, *, tenant_id: str, skill_id: str) -> SkillDefinition | None:
  42. stmt = (
  43. select(SkillDefinition)
  44. .where(SkillDefinition.tenant_id == tenant_id)
  45. .where(SkillDefinition.id == skill_id)
  46. )
  47. return self.db.scalar(stmt)
  48. def update_status(
  49. self,
  50. *,
  51. tenant_id: str,
  52. skill_id: str,
  53. status: SkillStatus,
  54. ) -> SkillDefinition | None:
  55. entity = self.get_by_id(tenant_id=tenant_id, skill_id=skill_id)
  56. if entity is None:
  57. return None
  58. entity.status = status
  59. self.db.commit()
  60. self.db.refresh(entity)
  61. return entity
  62. class SkillVersionRepository:
  63. def __init__(self, db: Session) -> None:
  64. self.db = db
  65. def create(
  66. self,
  67. *,
  68. tenant_id: str,
  69. skill_id: str,
  70. status: SkillVersionStatus,
  71. runtime_type: str,
  72. entrypoint: str | None,
  73. parameter_schema_json: dict[str, JSONValue],
  74. output_schema_json: dict[str, JSONValue],
  75. implementation_json: dict[str, JSONValue],
  76. ) -> SkillVersion:
  77. entity = SkillVersion(
  78. tenant_id=tenant_id,
  79. skill_id=skill_id,
  80. version_no=self._next_version_no(skill_id),
  81. status=status,
  82. runtime_type=runtime_type,
  83. entrypoint=entrypoint,
  84. parameter_schema_json=parameter_schema_json,
  85. output_schema_json=output_schema_json,
  86. implementation_json=implementation_json,
  87. published_time=datetime.utcnow() if status == "published" else None,
  88. )
  89. self.db.add(entity)
  90. self.db.commit()
  91. self.db.refresh(entity)
  92. return entity
  93. def list_by_skill(self, *, tenant_id: str, skill_id: str) -> list[SkillVersion]:
  94. stmt = (
  95. select(SkillVersion)
  96. .where(SkillVersion.tenant_id == tenant_id)
  97. .where(SkillVersion.skill_id == skill_id)
  98. .order_by(SkillVersion.version_no.desc())
  99. )
  100. return list(self.db.scalars(stmt))
  101. def get_by_id(self, *, tenant_id: str, skill_version_id: str) -> SkillVersion | None:
  102. stmt = (
  103. select(SkillVersion)
  104. .where(SkillVersion.tenant_id == tenant_id)
  105. .where(SkillVersion.id == skill_version_id)
  106. )
  107. return self.db.scalar(stmt)
  108. def get_latest_published(self, *, tenant_id: str, skill_id: str) -> SkillVersion | None:
  109. stmt = (
  110. select(SkillVersion)
  111. .where(SkillVersion.tenant_id == tenant_id)
  112. .where(SkillVersion.skill_id == skill_id)
  113. .where(SkillVersion.status == "published")
  114. .order_by(SkillVersion.version_no.desc())
  115. .limit(1)
  116. )
  117. return self.db.scalar(stmt)
  118. def _next_version_no(self, skill_id: str) -> int:
  119. stmt = select(func.max(SkillVersion.version_no)).where(SkillVersion.skill_id == skill_id)
  120. return (self.db.scalar(stmt) or 0) + 1
  121. class SkillInstallationRepository:
  122. def __init__(self, db: Session) -> None:
  123. self.db = db
  124. def create(
  125. self,
  126. *,
  127. tenant_id: str,
  128. skill_id: str,
  129. skill_version_id: str,
  130. install_scope: str,
  131. scope_id: str,
  132. config_json: dict[str, JSONValue],
  133. installed_by: str | None,
  134. ) -> SkillInstallation:
  135. entity = SkillInstallation(
  136. tenant_id=tenant_id,
  137. skill_id=skill_id,
  138. skill_version_id=skill_version_id,
  139. install_scope=install_scope,
  140. scope_id=scope_id,
  141. config_json=config_json,
  142. status="installed",
  143. installed_by=installed_by,
  144. installed_time=datetime.utcnow(),
  145. )
  146. self.db.add(entity)
  147. self.db.commit()
  148. self.db.refresh(entity)
  149. return entity
  150. def list_by_scope(
  151. self,
  152. *,
  153. tenant_id: str,
  154. install_scope: str | None = None,
  155. scope_id: str | None = None,
  156. ) -> list[SkillInstallation]:
  157. stmt = select(SkillInstallation).where(SkillInstallation.tenant_id == tenant_id)
  158. if install_scope is not None:
  159. stmt = stmt.where(SkillInstallation.install_scope == install_scope)
  160. if scope_id is not None:
  161. stmt = stmt.where(SkillInstallation.scope_id == scope_id)
  162. stmt = stmt.order_by(SkillInstallation.created_time.desc())
  163. return list(self.db.scalars(stmt))
  164. def get_by_id(self, *, tenant_id: str, installation_id: str) -> SkillInstallation | None:
  165. stmt = (
  166. select(SkillInstallation)
  167. .where(SkillInstallation.tenant_id == tenant_id)
  168. .where(SkillInstallation.id == installation_id)
  169. )
  170. return self.db.scalar(stmt)
  171. def update_status(
  172. self,
  173. *,
  174. tenant_id: str,
  175. installation_id: str,
  176. status: SkillInstallStatus,
  177. ) -> SkillInstallation | None:
  178. entity = self.get_by_id(tenant_id=tenant_id, installation_id=installation_id)
  179. if entity is None:
  180. return None
  181. entity.status = status
  182. self.db.commit()
  183. self.db.refresh(entity)
  184. return entity
  185. class SkillRunRepository:
  186. def __init__(self, db: Session) -> None:
  187. self.db = db
  188. def create(
  189. self,
  190. *,
  191. tenant_id: str,
  192. skill_id: str,
  193. skill_version_id: str,
  194. installation_id: str | None,
  195. input_json: dict[str, JSONValue],
  196. ) -> SkillRun:
  197. entity = SkillRun(
  198. tenant_id=tenant_id,
  199. skill_id=skill_id,
  200. skill_version_id=skill_version_id,
  201. installation_id=installation_id,
  202. input_json=input_json,
  203. status="queued",
  204. )
  205. self.db.add(entity)
  206. self.db.commit()
  207. self.db.refresh(entity)
  208. return entity
  209. def get_by_id(self, *, tenant_id: str, skill_run_id: str) -> SkillRun | None:
  210. stmt = (
  211. select(SkillRun)
  212. .where(SkillRun.tenant_id == tenant_id)
  213. .where(SkillRun.id == skill_run_id)
  214. )
  215. return self.db.scalar(stmt)
  216. def update_status(
  217. self,
  218. *,
  219. skill_run_id: str,
  220. status: SkillRunStatus,
  221. worker_key: str | None = None,
  222. output_json: dict[str, JSONValue] | None = None,
  223. output_text: str | None = None,
  224. error_code: str | None = None,
  225. error_message: str | None = None,
  226. ) -> SkillRun | None:
  227. entity = self.db.get(SkillRun, skill_run_id)
  228. if entity is None:
  229. return None
  230. now = datetime.utcnow()
  231. entity.status = status
  232. entity.worker_key = worker_key
  233. entity.output_json = output_json
  234. entity.output_text = output_text
  235. entity.error_code = error_code
  236. entity.error_message = error_message
  237. if status == "running" and entity.started_time is None:
  238. entity.started_time = now
  239. if status in {"completed", "failed", "cancelled"}:
  240. entity.finished_time = now
  241. self.db.commit()
  242. self.db.refresh(entity)
  243. return entity