# repositories.py (14 KB)
  1. from datetime import datetime
  2. from sqlalchemy import or_, select
  3. from sqlalchemy.orm import Session
  4. from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
  5. from core_domain import NodeRunStatus, WorkflowRunStatus
  6. from core_shared import JSONValue
  7. class WorkflowRunRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. tenant_id: str,
  14. app_id: str,
  15. app_version_id: str,
  16. workflow_id: str,
  17. workflow_version_id: str,
  18. session_id: str | None,
  19. parent_run_id: str | None,
  20. root_run_id: str | None,
  21. run_type: str,
  22. trigger_type: str,
  23. priority: int,
  24. ) -> WorkflowRun:
  25. now = datetime.utcnow()
  26. entity = WorkflowRun(
  27. tenant_id=tenant_id,
  28. app_id=app_id,
  29. app_version_id=app_version_id,
  30. workflow_id=workflow_id,
  31. workflow_version_id=workflow_version_id,
  32. session_id=session_id,
  33. parent_run_id=parent_run_id,
  34. root_run_id=root_run_id,
  35. run_type=run_type,
  36. trigger_type=trigger_type,
  37. priority=priority,
  38. status="running",
  39. started_time=now,
  40. )
  41. self.db.add(entity)
  42. self.db.commit()
  43. if entity.root_run_id is None:
  44. entity.root_run_id = entity.id
  45. self.db.commit()
  46. self.db.refresh(entity)
  47. return entity
  48. def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
  49. stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
  50. if session_id:
  51. stmt = stmt.where(WorkflowRun.session_id == session_id)
  52. stmt = stmt.order_by(WorkflowRun.created_time.desc())
  53. return list(self.db.scalars(stmt))
  54. def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
  55. entity = self.db.get(WorkflowRun, run_id)
  56. if entity is None:
  57. return
  58. entity.current_node_count = current_node_count
  59. self.db.commit()
  60. def get_by_id(self, run_id: str) -> WorkflowRun | None:
  61. return self.db.get(WorkflowRun, run_id)
  62. def update_status(
  63. self,
  64. *,
  65. run_id: str,
  66. status: WorkflowRunStatus,
  67. error_code: str | None = None,
  68. error_message: str | None = None,
  69. ) -> WorkflowRun | None:
  70. entity = self.db.get(WorkflowRun, run_id)
  71. if entity is None:
  72. return None
  73. entity.status = status
  74. entity.error_code = error_code
  75. entity.error_message = error_message
  76. now = datetime.utcnow()
  77. if status == "running" and entity.started_time is None:
  78. entity.started_time = now
  79. if status in {"completed", "failed", "cancelled"}:
  80. entity.finished_time = now
  81. self.db.commit()
  82. self.db.refresh(entity)
  83. return entity
  84. class NodeRunRepository:
  85. def __init__(self, db: Session) -> None:
  86. self.db = db
  87. def create(
  88. self,
  89. *,
  90. tenant_id: str,
  91. run_id: str,
  92. node_id: str,
  93. node_type: str,
  94. status: str,
  95. scheduled_time: datetime | None = None,
  96. timeout_time: datetime | None = None,
  97. parent_node_run_id: str | None = None,
  98. ) -> NodeRun:
  99. now = datetime.utcnow()
  100. entity = NodeRun(
  101. tenant_id=tenant_id,
  102. run_id=run_id,
  103. parent_node_run_id=parent_node_run_id,
  104. node_id=node_id,
  105. node_type=node_type,
  106. status=status,
  107. queued_time=now,
  108. scheduled_time=scheduled_time or now,
  109. timeout_time=timeout_time,
  110. )
  111. self.db.add(entity)
  112. self.db.commit()
  113. self.db.refresh(entity)
  114. return entity
  115. def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
  116. stmt = (
  117. select(NodeRun)
  118. .where(NodeRun.tenant_id == tenant_id)
  119. .where(NodeRun.run_id == run_id)
  120. .order_by(NodeRun.created_time.asc())
  121. )
  122. return list(self.db.scalars(stmt))
  123. def list_by_run_and_node_ids(
  124. self,
  125. *,
  126. tenant_id: str,
  127. run_id: str,
  128. node_ids: list[str],
  129. ) -> list[NodeRun]:
  130. if not node_ids:
  131. return []
  132. stmt = (
  133. select(NodeRun)
  134. .where(NodeRun.tenant_id == tenant_id)
  135. .where(NodeRun.run_id == run_id)
  136. .where(NodeRun.node_id.in_(node_ids))
  137. )
  138. return list(self.db.scalars(stmt))
  139. def get_by_id(self, node_run_id: str) -> NodeRun | None:
  140. return self.db.get(NodeRun, node_run_id)
  141. def get_next_queued_by_run(self, *, tenant_id: str, run_id: str) -> NodeRun | None:
  142. stmt = (
  143. select(NodeRun)
  144. .where(NodeRun.tenant_id == tenant_id)
  145. .where(NodeRun.run_id == run_id)
  146. .where(NodeRun.status == "queued")
  147. .where(
  148. or_(
  149. NodeRun.scheduled_time.is_(None),
  150. NodeRun.scheduled_time <= datetime.utcnow(),
  151. )
  152. )
  153. .order_by(NodeRun.created_time.asc())
  154. .limit(1)
  155. )
  156. return self.db.scalar(stmt)
  157. def claim_next_queued(
  158. self,
  159. *,
  160. worker_key: str,
  161. lease_expire_time: datetime,
  162. ) -> NodeRun | None:
  163. stmt = (
  164. select(NodeRun)
  165. .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
  166. .where(NodeRun.status == "queued")
  167. .where(
  168. or_(
  169. NodeRun.scheduled_time.is_(None),
  170. NodeRun.scheduled_time <= datetime.utcnow(),
  171. )
  172. )
  173. .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
  174. .with_for_update(skip_locked=True)
  175. .limit(1)
  176. )
  177. entity = self.db.scalar(stmt)
  178. if entity is None:
  179. return None
  180. now = datetime.utcnow()
  181. entity.status = "running"
  182. entity.worker_key = worker_key
  183. entity.started_time = entity.started_time or now
  184. entity.lease_expire_time = lease_expire_time
  185. self.db.commit()
  186. self.db.refresh(entity)
  187. return entity
  188. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  189. stmt = (
  190. select(NodeRun)
  191. .where(NodeRun.status == "running")
  192. .where(NodeRun.lease_expire_time.is_not(None))
  193. .where(NodeRun.lease_expire_time <= now_time)
  194. .order_by(NodeRun.lease_expire_time.asc())
  195. .limit(max_items)
  196. )
  197. entities = list(self.db.scalars(stmt))
  198. for entity in entities:
  199. entity.status = "queued"
  200. entity.worker_key = None
  201. entity.lease_expire_time = None
  202. entity.scheduled_time = now_time
  203. entity.queued_time = now_time
  204. entity.started_time = None
  205. entity.finished_time = None
  206. entity.attempt_no += 1
  207. if entities:
  208. self.db.commit()
  209. return len(entities)
  210. def update_status(
  211. self,
  212. *,
  213. node_run_id: str,
  214. status: NodeRunStatus,
  215. worker_key: str | None = None,
  216. error_code: str | None = None,
  217. error_message: str | None = None,
  218. output_text: str | None = None,
  219. output_json: dict[str, JSONValue] | None = None,
  220. ) -> NodeRun | None:
  221. entity = self.db.get(NodeRun, node_run_id)
  222. if entity is None:
  223. return None
  224. entity.status = status
  225. entity.worker_key = worker_key
  226. entity.error_code = error_code
  227. entity.error_message = error_message
  228. entity.output_text = output_text
  229. entity.output_json = output_json
  230. now = datetime.utcnow()
  231. if status == "running" and entity.started_time is None:
  232. entity.started_time = now
  233. if status in {"completed", "failed", "skipped"}:
  234. entity.finished_time = now
  235. entity.lease_expire_time = None
  236. self.db.commit()
  237. self.db.refresh(entity)
  238. return entity
  239. def requeue_for_retry(
  240. self,
  241. *,
  242. node_run_id: str,
  243. scheduled_time: datetime,
  244. timeout_time: datetime | None,
  245. error_code: str | None,
  246. error_message: str | None,
  247. output_text: str | None,
  248. output_json: dict[str, JSONValue] | None,
  249. ) -> NodeRun | None:
  250. entity = self.db.get(NodeRun, node_run_id)
  251. if entity is None:
  252. return None
  253. entity.status = "queued"
  254. entity.attempt_no += 1
  255. entity.worker_key = None
  256. entity.lease_expire_time = None
  257. entity.scheduled_time = scheduled_time
  258. entity.timeout_time = timeout_time
  259. entity.queued_time = datetime.utcnow()
  260. entity.started_time = None
  261. entity.finished_time = None
  262. entity.error_code = error_code
  263. entity.error_message = error_message
  264. entity.output_text = output_text
  265. entity.output_json = output_json
  266. self.db.commit()
  267. self.db.refresh(entity)
  268. return entity
  269. class ExecutionLogRepository:
  270. def __init__(self, db: Session) -> None:
  271. self.db = db
  272. def create(
  273. self,
  274. *,
  275. tenant_id: str,
  276. run_id: str,
  277. node_run_id: str | None,
  278. event_type: str,
  279. level: str,
  280. message: str,
  281. detail_json: dict[str, JSONValue] | None,
  282. ) -> ExecutionLog:
  283. entity = ExecutionLog(
  284. tenant_id=tenant_id,
  285. run_id=run_id,
  286. node_run_id=node_run_id,
  287. event_type=event_type,
  288. level=level,
  289. message=message,
  290. detail_json=detail_json,
  291. )
  292. self.db.add(entity)
  293. self.db.commit()
  294. self.db.refresh(entity)
  295. return entity
  296. def list_by_scope(
  297. self,
  298. *,
  299. tenant_id: str,
  300. run_id: str | None = None,
  301. node_run_id: str | None = None,
  302. ) -> list[ExecutionLog]:
  303. stmt = select(ExecutionLog).where(ExecutionLog.tenant_id == tenant_id)
  304. if run_id is not None:
  305. stmt = stmt.where(ExecutionLog.run_id == run_id)
  306. if node_run_id is not None:
  307. stmt = stmt.where(ExecutionLog.node_run_id == node_run_id)
  308. stmt = stmt.order_by(ExecutionLog.created_time.asc())
  309. return list(self.db.scalars(stmt))
  310. class NodeArtifactRepository:
  311. def __init__(self, db: Session) -> None:
  312. self.db = db
  313. def create(
  314. self,
  315. *,
  316. tenant_id: str,
  317. run_id: str,
  318. node_run_id: str,
  319. node_id: str,
  320. artifact_type: str,
  321. name: str,
  322. mime_type: str | None,
  323. content_text: str | None,
  324. content_json: dict[str, JSONValue] | None,
  325. storage_uri: str | None = None,
  326. size_bytes: int | None = None,
  327. ) -> NodeArtifact:
  328. entity = NodeArtifact(
  329. tenant_id=tenant_id,
  330. run_id=run_id,
  331. node_run_id=node_run_id,
  332. node_id=node_id,
  333. artifact_type=artifact_type,
  334. name=name,
  335. mime_type=mime_type,
  336. content_text=content_text,
  337. content_json=content_json,
  338. storage_uri=storage_uri,
  339. size_bytes=size_bytes,
  340. )
  341. self.db.add(entity)
  342. self.db.commit()
  343. self.db.refresh(entity)
  344. return entity
  345. def list_by_scope(
  346. self,
  347. *,
  348. tenant_id: str,
  349. run_id: str | None = None,
  350. node_run_id: str | None = None,
  351. artifact_type: str | None = None,
  352. ) -> list[NodeArtifact]:
  353. stmt = select(NodeArtifact).where(NodeArtifact.tenant_id == tenant_id)
  354. if run_id is not None:
  355. stmt = stmt.where(NodeArtifact.run_id == run_id)
  356. if node_run_id is not None:
  357. stmt = stmt.where(NodeArtifact.node_run_id == node_run_id)
  358. if artifact_type is not None:
  359. stmt = stmt.where(NodeArtifact.artifact_type == artifact_type)
  360. stmt = stmt.order_by(NodeArtifact.created_time.asc())
  361. return list(self.db.scalars(stmt))
  362. class TraceSpanRepository:
  363. def __init__(self, db: Session) -> None:
  364. self.db = db
  365. def start(
  366. self,
  367. *,
  368. tenant_id: str,
  369. run_id: str,
  370. node_run_id: str | None,
  371. parent_span_id: str | None,
  372. span_type: str,
  373. name: str,
  374. attributes_json: dict[str, JSONValue] | None = None,
  375. ) -> TraceSpan:
  376. entity = TraceSpan(
  377. tenant_id=tenant_id,
  378. run_id=run_id,
  379. node_run_id=node_run_id,
  380. parent_span_id=parent_span_id,
  381. span_type=span_type,
  382. name=name,
  383. status="running",
  384. started_time=datetime.utcnow(),
  385. attributes_json=attributes_json,
  386. )
  387. self.db.add(entity)
  388. self.db.commit()
  389. self.db.refresh(entity)
  390. return entity
  391. def finish(
  392. self,
  393. *,
  394. span_id: str,
  395. status: str,
  396. error_code: str | None = None,
  397. error_message: str | None = None,
  398. attributes_json: dict[str, JSONValue] | None = None,
  399. ) -> TraceSpan | None:
  400. entity = self.db.get(TraceSpan, span_id)
  401. if entity is None:
  402. return None
  403. ended_time = datetime.utcnow()
  404. entity.status = status
  405. entity.ended_time = ended_time
  406. entity.duration_ms = int((ended_time - entity.started_time).total_seconds() * 1000)
  407. entity.error_code = error_code
  408. entity.error_message = error_message
  409. if attributes_json is not None:
  410. entity.attributes_json = {
  411. **(entity.attributes_json or {}),
  412. **attributes_json,
  413. }
  414. self.db.commit()
  415. self.db.refresh(entity)
  416. return entity
  417. def list_by_scope(
  418. self,
  419. *,
  420. tenant_id: str,
  421. run_id: str | None = None,
  422. node_run_id: str | None = None,
  423. span_type: str | None = None,
  424. ) -> list[TraceSpan]:
  425. stmt = select(TraceSpan).where(TraceSpan.tenant_id == tenant_id)
  426. if run_id is not None:
  427. stmt = stmt.where(TraceSpan.run_id == run_id)
  428. if node_run_id is not None:
  429. stmt = stmt.where(TraceSpan.node_run_id == node_run_id)
  430. if span_type is not None:
  431. stmt = stmt.where(TraceSpan.span_type == span_type)
  432. stmt = stmt.order_by(TraceSpan.started_time.asc())
  433. return list(self.db.scalars(stmt))