# repositories.py — SQLAlchemy session-backed repositories for skill
# definitions, skill versions, skill installations, and skill runs.
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import SkillInstallStatus, SkillRunStatus, SkillStatus, SkillVersionStatus
from core_shared import JSONValue

from app.db.models import SkillDefinition, SkillInstallation, SkillRun, SkillVersion
  7. class SkillDefinitionRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. code: str,
  14. name: str,
  15. skill_type: str,
  16. description: str | None,
  17. owner_user_id: str | None,
  18. metadata_json: dict[str, JSONValue] | None) -> SkillDefinition:
  19. entity = SkillDefinition(
  20. code=code,
  21. name=name,
  22. skill_type=skill_type,
  23. description=description,
  24. owner_user_id=owner_user_id,
  25. metadata_json=metadata_json)
  26. self.db.add(entity)
  27. self.db.commit()
  28. self.db.refresh(entity)
  29. return entity
  30. def list_all(self) -> list[SkillDefinition]:
  31. stmt = (
  32. select(SkillDefinition)
  33. .order_by(SkillDefinition.created_time.desc())
  34. )
  35. return list(self.db.scalars(stmt))
  36. def get_by_id(self, *, skill_id: str) -> SkillDefinition | None:
  37. stmt = (
  38. select(SkillDefinition)
  39. .where(SkillDefinition.id == skill_id)
  40. )
  41. return self.db.scalar(stmt)
  42. def update_status(
  43. self,
  44. *,
  45. skill_id: str,
  46. status: SkillStatus) -> SkillDefinition | None:
  47. entity = self.get_by_id(skill_id=skill_id)
  48. if entity is None:
  49. return None
  50. entity.status = status
  51. self.db.commit()
  52. self.db.refresh(entity)
  53. return entity
  54. class SkillVersionRepository:
  55. def __init__(self, db: Session) -> None:
  56. self.db = db
  57. def create(
  58. self,
  59. *,
  60. skill_id: str,
  61. status: SkillVersionStatus,
  62. runtime_type: str,
  63. entrypoint: str | None,
  64. parameter_schema_json: dict[str, JSONValue],
  65. output_schema_json: dict[str, JSONValue],
  66. implementation_json: dict[str, JSONValue]) -> SkillVersion:
  67. entity = SkillVersion(
  68. skill_id=skill_id,
  69. version_no=self._next_version_no(skill_id),
  70. status=status,
  71. runtime_type=runtime_type,
  72. entrypoint=entrypoint,
  73. parameter_schema_json=parameter_schema_json,
  74. output_schema_json=output_schema_json,
  75. implementation_json=implementation_json,
  76. published_time=datetime.utcnow() if status == "published" else None)
  77. self.db.add(entity)
  78. self.db.commit()
  79. self.db.refresh(entity)
  80. return entity
  81. def list_by_skill(self, *, skill_id: str) -> list[SkillVersion]:
  82. stmt = (
  83. select(SkillVersion)
  84. .where(SkillVersion.skill_id == skill_id)
  85. .order_by(SkillVersion.version_no.desc())
  86. )
  87. return list(self.db.scalars(stmt))
  88. def get_by_id(self, *, skill_version_id: str) -> SkillVersion | None:
  89. stmt = (
  90. select(SkillVersion)
  91. .where(SkillVersion.id == skill_version_id)
  92. )
  93. return self.db.scalar(stmt)
  94. def get_latest_published(self, *, skill_id: str) -> SkillVersion | None:
  95. stmt = (
  96. select(SkillVersion)
  97. .where(SkillVersion.skill_id == skill_id)
  98. .where(SkillVersion.status == "published")
  99. .order_by(SkillVersion.version_no.desc())
  100. .limit(1)
  101. )
  102. return self.db.scalar(stmt)
  103. def _next_version_no(self, skill_id: str) -> int:
  104. stmt = select(func.max(SkillVersion.version_no)).where(SkillVersion.skill_id == skill_id)
  105. return (self.db.scalar(stmt) or 0) + 1
  106. class SkillInstallationRepository:
  107. def __init__(self, db: Session) -> None:
  108. self.db = db
  109. def create(
  110. self,
  111. *,
  112. skill_id: str,
  113. skill_version_id: str,
  114. install_scope: str,
  115. scope_id: str,
  116. config_json: dict[str, JSONValue],
  117. installed_by: str | None) -> SkillInstallation:
  118. entity = SkillInstallation(
  119. skill_id=skill_id,
  120. skill_version_id=skill_version_id,
  121. install_scope=install_scope,
  122. scope_id=scope_id,
  123. config_json=config_json,
  124. status="installed",
  125. installed_by=installed_by,
  126. installed_time=datetime.utcnow())
  127. self.db.add(entity)
  128. self.db.commit()
  129. self.db.refresh(entity)
  130. return entity
  131. def list_by_scope(
  132. self,
  133. *,
  134. install_scope: str | None = None,
  135. scope_id: str | None = None) -> list[SkillInstallation]:
  136. stmt = select(SkillInstallation)
  137. if install_scope is not None:
  138. stmt = stmt.where(SkillInstallation.install_scope == install_scope)
  139. if scope_id is not None:
  140. stmt = stmt.where(SkillInstallation.scope_id == scope_id)
  141. stmt = stmt.order_by(SkillInstallation.created_time.desc())
  142. return list(self.db.scalars(stmt))
  143. def get_by_id(self, *, installation_id: str) -> SkillInstallation | None:
  144. stmt = (
  145. select(SkillInstallation)
  146. .where(SkillInstallation.id == installation_id)
  147. )
  148. return self.db.scalar(stmt)
  149. def update_status(
  150. self,
  151. *,
  152. installation_id: str,
  153. status: SkillInstallStatus) -> SkillInstallation | None:
  154. entity = self.get_by_id(installation_id=installation_id)
  155. if entity is None:
  156. return None
  157. entity.status = status
  158. self.db.commit()
  159. self.db.refresh(entity)
  160. return entity
  161. class SkillRunRepository:
  162. def __init__(self, db: Session) -> None:
  163. self.db = db
  164. def create(
  165. self,
  166. *,
  167. skill_id: str,
  168. skill_version_id: str,
  169. installation_id: str | None,
  170. input_json: dict[str, JSONValue]) -> SkillRun:
  171. entity = SkillRun(
  172. skill_id=skill_id,
  173. skill_version_id=skill_version_id,
  174. installation_id=installation_id,
  175. input_json=input_json,
  176. status="queued")
  177. self.db.add(entity)
  178. self.db.commit()
  179. self.db.refresh(entity)
  180. return entity
  181. def get_by_id(self, *, skill_run_id: str) -> SkillRun | None:
  182. stmt = (
  183. select(SkillRun)
  184. .where(SkillRun.id == skill_run_id)
  185. )
  186. return self.db.scalar(stmt)
  187. def update_status(
  188. self,
  189. *,
  190. skill_run_id: str,
  191. status: SkillRunStatus,
  192. worker_key: str | None = None,
  193. output_json: dict[str, JSONValue] | None = None,
  194. output_text: str | None = None,
  195. error_code: str | None = None,
  196. error_message: str | None = None) -> SkillRun | None:
  197. entity = self.db.get(SkillRun, skill_run_id)
  198. if entity is None:
  199. return None
  200. now = datetime.utcnow()
  201. entity.status = status
  202. entity.worker_key = worker_key
  203. entity.output_json = output_json
  204. entity.output_text = output_text
  205. entity.error_code = error_code
  206. entity.error_message = error_message
  207. if status == "running" and entity.started_time is None:
  208. entity.started_time = now
  209. if status in {"completed", "failed", "cancelled"}:
  210. entity.finished_time = now
  211. self.db.commit()
  212. self.db.refresh(entity)
  213. return entity