# repositories.py
  from datetime import datetime, timezone

  from sqlalchemy import select
  from sqlalchemy.orm import Session

  from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
  from core_domain import NodeRunStatus, WorkflowRunStatus
  from core_shared import JSONValue
  class WorkflowRunRepository:
      """Data access for ``WorkflowRun`` rows.

      Every mutating method commits the session passed to ``__init__``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          app_id: str,
          app_version_id: str,
          workflow_id: str,
          workflow_version_id: str,
          session_id: str | None,
          parent_run_id: str | None,
          root_run_id: str | None,
          run_type: str,
          trigger_type: str,
          priority: int,
      ) -> WorkflowRun:
          """Insert a workflow run with status ``"running"`` and return it.

          ``started_time`` is set to the current UTC time. When
          ``root_run_id`` is ``None`` the run becomes its own root: the column
          is filled with the run's own primary key before the commit.

          Returns:
              The persisted run, refreshed from the database.
          """
          # Naive UTC timestamp: the same value ``datetime.utcnow()`` returns,
          # without that call's Python 3.12+ deprecation warning.
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          entity = WorkflowRun(
              tenant_id=tenant_id,
              app_id=app_id,
              app_version_id=app_version_id,
              workflow_id=workflow_id,
              workflow_version_id=workflow_version_id,
              session_id=session_id,
              parent_run_id=parent_run_id,
              root_run_id=root_run_id,
              run_type=run_type,
              trigger_type=trigger_type,
              priority=priority,
              status="running",
              started_time=now,
          )
          self.db.add(entity)
          if entity.root_run_id is None:
              # flush() emits the INSERT inside the open transaction, so the
              # primary key is assigned. The self-reference is then stored by
              # the single commit below, and a run is never committed with a
              # NULL root_run_id.
              self.db.flush()
              entity.root_run_id = entity.id
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
          """Return the tenant's runs, newest first.

          The session filter is applied only when ``session_id`` is truthy, so
          an empty string behaves like ``None`` (no session filter).
          """
          stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
          if session_id:
              stmt = stmt.where(WorkflowRun.session_id == session_id)
          stmt = stmt.order_by(WorkflowRun.created_time.desc())
          return list(self.db.scalars(stmt))

      def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
          """Store ``current_node_count`` on the run; do nothing if the run does not exist."""
          entity = self.db.get(WorkflowRun, run_id)
          if entity is None:
              return
          entity.current_node_count = current_node_count
          self.db.commit()

      def get_by_id(self, run_id: str) -> WorkflowRun | None:
          """Return the run with primary key ``run_id``, or ``None``."""
          return self.db.get(WorkflowRun, run_id)

      def update_status(
          self,
          *,
          run_id: str,
          status: WorkflowRunStatus,
          error_code: str | None = None,
          error_message: str | None = None,
      ) -> WorkflowRun | None:
          """Set the run's status and error fields and return the refreshed run.

          ``error_code`` and ``error_message`` are always overwritten, so
          omitting them clears any previous error. ``started_time`` is stamped
          on the first transition to ``"running"``. ``"completed"``,
          ``"failed"`` and ``"cancelled"`` stamp ``finished_time``.

          Returns:
              The updated run, or ``None`` if ``run_id`` does not exist.
          """
          entity = self.db.get(WorkflowRun, run_id)
          if entity is None:
              return None
          entity.status = status
          entity.error_code = error_code
          entity.error_message = error_message
          # Naive UTC, equivalent to the deprecated datetime.utcnow().
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          if status in {"completed", "failed", "cancelled"}:
              entity.finished_time = now
          self.db.commit()
          self.db.refresh(entity)
          return entity
  class NodeRunRepository:
      """Data access and work-queue operations for ``NodeRun`` rows.

      ``claim_next_queued`` leases a ``"queued"`` row to a worker, and
      ``release_expired_leases`` puts ``"running"`` rows whose lease has lapsed
      back to ``"queued"``. Every mutating method commits the session.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          run_id: str,
          node_id: str,
          node_type: str,
          status: str,
      ) -> NodeRun:
          """Insert a node run with the given initial ``status`` and return it.

          ``queued_time`` is set to the current UTC time regardless of ``status``.
          """
          # Naive UTC, equivalent to the deprecated datetime.utcnow().
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          entity = NodeRun(
              tenant_id=tenant_id,
              run_id=run_id,
              node_id=node_id,
              node_type=node_type,
              status=status,
              queued_time=now,
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
          """Return every node run of ``run_id`` in creation order."""
          stmt = (
              select(NodeRun)
              .where(NodeRun.tenant_id == tenant_id)
              .where(NodeRun.run_id == run_id)
              .order_by(NodeRun.created_time.asc())
          )
          return list(self.db.scalars(stmt))

      def list_by_run_and_node_ids(
          self,
          *,
          tenant_id: str,
          run_id: str,
          node_ids: list[str],
      ) -> list[NodeRun]:
          """Return the node runs of ``run_id`` whose ``node_id`` is in ``node_ids``.

          An empty ``node_ids`` returns ``[]`` without querying. Results are in
          creation order (as in ``list_by_run``) so the output is deterministic.
          """
          if not node_ids:
              return []
          stmt = (
              select(NodeRun)
              .where(NodeRun.tenant_id == tenant_id)
              .where(NodeRun.run_id == run_id)
              .where(NodeRun.node_id.in_(node_ids))
              .order_by(NodeRun.created_time.asc())
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, node_run_id: str) -> NodeRun | None:
          """Return the node run with primary key ``node_run_id``, or ``None``."""
          return self.db.get(NodeRun, node_run_id)

      def get_next_queued_by_run(self, *, tenant_id: str, run_id: str) -> NodeRun | None:
          """Return the oldest ``"queued"`` node run of ``run_id``, or ``None``.

          This is a plain read: no row lock is taken and nothing is modified.
          """
          stmt = (
              select(NodeRun)
              .where(NodeRun.tenant_id == tenant_id)
              .where(NodeRun.run_id == run_id)
              .where(NodeRun.status == "queued")
              .order_by(NodeRun.created_time.asc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def claim_next_queued(
          self,
          *,
          worker_key: str,
          lease_expire_time: datetime,
      ) -> NodeRun | None:
          """Lease the next queued node run, across all tenants, to ``worker_key``.

          Candidates are ordered by their workflow run's ``priority``
          (descending), then by creation time. The chosen row is set to
          ``"running"`` with the given ``worker_key`` and ``lease_expire_time``.
          An existing ``started_time`` is kept.

          Returns:
              The claimed node run, or ``None`` if nothing is queued (or every
              queued row is currently locked by another transaction).
          """
          stmt = (
              select(NodeRun)
              .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
              .where(NodeRun.status == "queued")
              .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
              # Lock only the node_run row. Without ``of=``, FOR UPDATE on a
              # join also locks the joined workflow_run row. SKIP LOCKED then
              # makes concurrent workers skip every sibling node of a run that
              # is being claimed, and other writers to that run row must wait.
              .with_for_update(skip_locked=True, of=NodeRun)
              .limit(1)
          )
          entity = self.db.scalar(stmt)
          if entity is None:
              return None
          # Naive UTC, equivalent to the deprecated datetime.utcnow().
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          entity.status = "running"
          entity.worker_key = worker_key
          entity.started_time = entity.started_time or now
          entity.lease_expire_time = lease_expire_time
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
          """Requeue up to ``max_items`` running node runs whose lease expired by ``now_time``.

          The earliest expirations are handled first. Each released row gets
          status ``"queued"``, ``queued_time = now_time``, and ``attempt_no``
          incremented. Its worker key, lease, start and finish times are cleared.

          Returns:
              The number of node runs released.
          """
          stmt = (
              select(NodeRun)
              .where(NodeRun.status == "running")
              .where(NodeRun.lease_expire_time.is_not(None))
              .where(NodeRun.lease_expire_time <= now_time)
              .order_by(NodeRun.lease_expire_time.asc())
              # Lock the selected rows so this read-modify-write is not based
              # on a stale read. Rows another transaction is updating (e.g. a
              # worker committing completion) are skipped, and writers arriving
              # later wait for this commit. This matches claim_next_queued.
              .with_for_update(skip_locked=True)
              .limit(max_items)
          )
          entities = list(self.db.scalars(stmt))
          for entity in entities:
              entity.status = "queued"
              entity.worker_key = None
              entity.lease_expire_time = None
              entity.queued_time = now_time
              entity.started_time = None
              entity.finished_time = None
              entity.attempt_no += 1
          if entities:
              self.db.commit()
          return len(entities)

      def update_status(
          self,
          *,
          node_run_id: str,
          status: NodeRunStatus,
          worker_key: str | None = None,
          error_code: str | None = None,
          error_message: str | None = None,
          output_text: str | None = None,
          output_json: dict[str, JSONValue] | None = None,
      ) -> NodeRun | None:
          """Set a node run's status, worker, error and output fields.

          ``worker_key`` and the error and output fields are always
          overwritten, so omitted arguments clear them. ``started_time`` is
          stamped on the first transition to ``"running"``. ``"completed"``,
          ``"failed"`` and ``"skipped"`` stamp ``finished_time`` and clear the
          lease.

          Returns:
              The refreshed node run, or ``None`` if ``node_run_id`` does not exist.
          """
          entity = self.db.get(NodeRun, node_run_id)
          if entity is None:
              return None
          entity.status = status
          entity.worker_key = worker_key
          entity.error_code = error_code
          entity.error_message = error_message
          entity.output_text = output_text
          entity.output_json = output_json
          # Naive UTC, equivalent to the deprecated datetime.utcnow().
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          # NOTE(review): if NodeRunStatus has other terminal values (e.g. a
          # "cancelled" member), they are not stamped/unleased here. Confirm.
          if status in {"completed", "failed", "skipped"}:
              entity.finished_time = now
              entity.lease_expire_time = None
          self.db.commit()
          self.db.refresh(entity)
          return entity
  class ExecutionLogRepository:
      """Append and query ``ExecutionLog`` entries.

      ``create`` commits the session passed to ``__init__``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def _save(self, record: ExecutionLog) -> ExecutionLog:
          """Add *record* to the session, commit, reload it and return it."""
          session = self.db
          session.add(record)
          session.commit()
          session.refresh(record)
          return record

      def create(
          self,
          *,
          tenant_id: str,
          run_id: str,
          node_run_id: str | None,
          event_type: str,
          level: str,
          message: str,
          detail_json: dict[str, JSONValue] | None,
      ) -> ExecutionLog:
          """Write one log entry for a run (and optionally a node run).

          Returns:
              The stored entry, refreshed from the database.
          """
          record = ExecutionLog(
              tenant_id=tenant_id,
              run_id=run_id,
              node_run_id=node_run_id,
              event_type=event_type,
              level=level,
              message=message,
              detail_json=detail_json,
          )
          return self._save(record)

      def list_by_scope(
          self,
          *,
          tenant_id: str,
          run_id: str | None = None,
          node_run_id: str | None = None,
      ) -> list[ExecutionLog]:
          """Return the tenant's log entries, oldest first.

          ``run_id`` and ``node_run_id`` each narrow the result when they are
          not ``None``. All filters are combined with AND.
          """
          conditions = [ExecutionLog.tenant_id == tenant_id]
          optional_filters = (
              (ExecutionLog.run_id, run_id),
              (ExecutionLog.node_run_id, node_run_id),
          )
          conditions.extend(
              column == value
              for column, value in optional_filters
              if value is not None
          )
          query = (
              select(ExecutionLog)
              .where(*conditions)
              .order_by(ExecutionLog.created_time.asc())
          )
          return list(self.db.scalars(query))
  class NodeArtifactRepository:
      """Store and query ``NodeArtifact`` rows attached to node runs.

      ``create`` commits the session passed to ``__init__``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def _save(self, artifact: NodeArtifact) -> NodeArtifact:
          """Add *artifact* to the session, commit, reload it and return it."""
          session = self.db
          session.add(artifact)
          session.commit()
          session.refresh(artifact)
          return artifact

      def create(
          self,
          *,
          tenant_id: str,
          run_id: str,
          node_run_id: str,
          node_id: str,
          artifact_type: str,
          name: str,
          mime_type: str | None,
          content_text: str | None,
          content_json: dict[str, JSONValue] | None,
          storage_uri: str | None = None,
          size_bytes: int | None = None,
      ) -> NodeArtifact:
          """Record one artifact for a node run and return the stored row.

          All values are stored exactly as passed. Nothing is validated here.
          """
          artifact = NodeArtifact(
              tenant_id=tenant_id,
              run_id=run_id,
              node_run_id=node_run_id,
              node_id=node_id,
              artifact_type=artifact_type,
              name=name,
              mime_type=mime_type,
              content_text=content_text,
              content_json=content_json,
              storage_uri=storage_uri,
              size_bytes=size_bytes,
          )
          return self._save(artifact)

      def list_by_scope(
          self,
          *,
          tenant_id: str,
          run_id: str | None = None,
          node_run_id: str | None = None,
          artifact_type: str | None = None,
      ) -> list[NodeArtifact]:
          """Return the tenant's artifacts, oldest first.

          ``run_id``, ``node_run_id`` and ``artifact_type`` each narrow the
          result when they are not ``None``. All filters are combined with AND.
          """
          criteria = [NodeArtifact.tenant_id == tenant_id]
          for column, wanted in (
              (NodeArtifact.run_id, run_id),
              (NodeArtifact.node_run_id, node_run_id),
              (NodeArtifact.artifact_type, artifact_type),
          ):
              if wanted is not None:
                  criteria.append(column == wanted)
          query = (
              select(NodeArtifact)
              .where(*criteria)
              .order_by(NodeArtifact.created_time.asc())
          )
          return list(self.db.scalars(query))
  class TraceSpanRepository:
      """Data access for ``TraceSpan`` rows: start, finish and query spans.

      ``start`` and ``finish`` commit the session passed to ``__init__``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def start(
          self,
          *,
          tenant_id: str,
          run_id: str,
          node_run_id: str | None,
          parent_span_id: str | None,
          span_type: str,
          name: str,
          attributes_json: dict[str, JSONValue] | None = None,
      ) -> TraceSpan:
          """Open a span with status ``"running"`` started now (UTC) and return it."""
          entity = TraceSpan(
              tenant_id=tenant_id,
              run_id=run_id,
              node_run_id=node_run_id,
              parent_span_id=parent_span_id,
              span_type=span_type,
              name=name,
              status="running",
              # Naive UTC, equivalent to the deprecated datetime.utcnow().
              started_time=datetime.now(timezone.utc).replace(tzinfo=None),
              attributes_json=attributes_json,
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def finish(
          self,
          *,
          span_id: str,
          status: str,
          error_code: str | None = None,
          error_message: str | None = None,
          attributes_json: dict[str, JSONValue] | None = None,
      ) -> TraceSpan | None:
          """Close a span and return it, or return ``None`` if ``span_id`` does not exist.

          Sets ``status``, ``ended_time`` and ``duration_ms``. The duration is
          the time since ``started_time``, truncated to whole milliseconds.
          The error fields are always overwritten. When ``attributes_json`` is
          given, it is shallow-merged over the existing attributes, and new
          keys win.
          """
          entity = self.db.get(TraceSpan, span_id)
          if entity is None:
              return None
          # Naive UTC, matching the naive value stored by ``start``; an aware
          # datetime here could not be subtracted from it.
          ended_time = datetime.now(timezone.utc).replace(tzinfo=None)
          entity.status = status
          entity.ended_time = ended_time
          entity.duration_ms = int((ended_time - entity.started_time).total_seconds() * 1000)
          entity.error_code = error_code
          entity.error_message = error_message
          if attributes_json is not None:
              # A new dict (rather than in-place mutation) makes the change
              # visible to the ORM even if the JSON column is not mutation-tracked.
              entity.attributes_json = {
                  **(entity.attributes_json or {}),
                  **attributes_json,
              }
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(
          self,
          *,
          tenant_id: str,
          run_id: str | None = None,
          node_run_id: str | None = None,
          span_type: str | None = None,
      ) -> list[TraceSpan]:
          """Return the tenant's spans ordered by start time.

          ``run_id``, ``node_run_id`` and ``span_type`` each narrow the result
          when they are not ``None``.
          """
          stmt = select(TraceSpan).where(TraceSpan.tenant_id == tenant_id)
          if run_id is not None:
              stmt = stmt.where(TraceSpan.run_id == run_id)
          if node_run_id is not None:
              stmt = stmt.where(TraceSpan.node_run_id == node_run_id)
          if span_type is not None:
              stmt = stmt.where(TraceSpan.span_type == span_type)
          stmt = stmt.order_by(TraceSpan.started_time.asc())
          return list(self.db.scalars(stmt))