repositories.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. from datetime import datetime
  2. from core_domain import NodeRunStatus, WorkflowRunStatus
  3. from core_shared import JSONValue
  4. from sqlalchemy import or_, select
  5. from sqlalchemy.orm import Session
  6. from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
  class WorkflowRunRepository:
      """Persistence operations for ``WorkflowRun`` rows.

      Every mutating method commits on the injected session before returning.
      """

      def __init__(self, db: Session) -> None:
          """Keep the session that all queries and commits go through."""
          self.db = db

      def create(
          self,
          *,
          app_id: str,
          app_config_id: str,
          workflow_id: str,
          workflow_config_id: str,
          session_id: str | None,
          parent_run_id: str | None,
          root_run_id: str | None,
          run_type: str,
          trigger_type: str,
          priority: int,
      ) -> WorkflowRun:
          """Insert a new run with status ``"running"`` and return it.

          ``started_time`` is set to the current UTC time. When ``root_run_id``
          is ``None`` the run becomes its own root: its generated ``id`` is
          copied into ``root_run_id``.

          Returns:
              The persisted, refreshed ``WorkflowRun``.
          """
          now = datetime.utcnow()
          entity = WorkflowRun(
              app_id=app_id,
              app_config_id=app_config_id,
              workflow_id=workflow_id,
              workflow_config_id=workflow_config_id,
              session_id=session_id,
              parent_run_id=parent_run_id,
              root_run_id=root_run_id,
              run_type=run_type,
              trigger_type=trigger_type,
              priority=priority,
              status="running",
              started_time=now,
          )
          self.db.add(entity)
          # Flush instead of committing: the INSERT is emitted and the primary
          # key becomes available, but the transaction stays open.  The
          # self-referencing root id is then written in the same transaction,
          # so no other session can observe a committed row whose root_run_id
          # is still NULL, and a failure cannot leave such a row behind.
          self.db.flush()
          if entity.root_run_id is None:
              entity.root_run_id = entity.id
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(
          self,
          *,
          session_id: str | None = None,
          limit: int = 50,
      ) -> list[WorkflowRun]:
          """Return up to ``limit`` runs, newest ``created_time`` first.

          The session filter is applied only when ``session_id`` is truthy, so
          both ``None`` and ``""`` list runs across all sessions.
          """
          stmt = select(WorkflowRun)
          if session_id:
              stmt = stmt.where(WorkflowRun.session_id == session_id)
          stmt = stmt.order_by(WorkflowRun.created_time.desc()).limit(limit)
          return list(self.db.scalars(stmt))

      def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
          """Store ``current_node_count`` on the run; do nothing if it is missing."""
          entity = self.db.get(WorkflowRun, run_id)
          if entity is None:
              return
          entity.current_node_count = current_node_count
          self.db.commit()

      def get_by_id(self, run_id: str) -> WorkflowRun | None:
          """Return the run with primary key ``run_id``, or ``None``."""
          return self.db.get(WorkflowRun, run_id)

      def update_status(
          self,
          *,
          run_id: str,
          status: WorkflowRunStatus,
          error_code: str | None = None,
          error_message: str | None = None,
      ) -> WorkflowRun | None:
          """Set the run's status and error fields, stamping lifecycle times.

          ``error_code`` and ``error_message`` are always overwritten, so
          omitting them clears any previously stored error.

          Returns:
              The refreshed run, or ``None`` when ``run_id`` does not exist.
          """
          entity = self.db.get(WorkflowRun, run_id)
          if entity is None:
              return None
          entity.status = status
          entity.error_code = error_code
          entity.error_message = error_message
          now = datetime.utcnow()
          # Keep the first start time if the run is marked running again.
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          # These three statuses stamp the finish time.
          if status in {"completed", "failed", "cancelled"}:
              entity.finished_time = now
          self.db.commit()
          self.db.refresh(entity)
          return entity
  class NodeRunRepository:
      """Persistence and queue operations for ``NodeRun`` rows.

      Node runs double as a work queue: rows with status ``"queued"`` are
      claimed by workers, marked ``"running"`` with a lease, and can be put
      back in the queue when the lease expires or a retry is requested.
      Every mutating method commits on the injected session.
      """

      def __init__(self, db: Session) -> None:
          """Keep the session that all queries and commits go through."""
          self.db = db

      def create(
          self,
          *,
          run_id: str,
          node_id: str,
          node_type: str,
          status: str,
          scheduled_time: datetime | None = None,
          timeout_time: datetime | None = None,
          parent_node_run_id: str | None = None,
      ) -> NodeRun:
          """Insert a node run and return the refreshed row.

          ``queued_time`` is the current UTC time; ``scheduled_time`` falls
          back to that same instant when not supplied.
          """
          now = datetime.utcnow()
          entity = NodeRun(
              run_id=run_id,
              parent_node_run_id=parent_node_run_id,
              node_id=node_id,
              node_type=node_type,
              status=status,
              queued_time=now,
              scheduled_time=scheduled_time or now,
              timeout_time=timeout_time,
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_run(self, *, run_id: str) -> list[NodeRun]:
          """Return all node runs of ``run_id``, oldest ``created_time`` first."""
          stmt = (
              select(NodeRun)
              .where(NodeRun.run_id == run_id)
              .order_by(NodeRun.created_time.asc())
          )
          return list(self.db.scalars(stmt))

      def list_by_run_and_node_ids(
          self,
          *,
          run_id: str,
          node_ids: list[str],
      ) -> list[NodeRun]:
          """Return node runs of ``run_id`` whose ``node_id`` is in ``node_ids``.

          An empty ``node_ids`` returns ``[]`` without querying. No ordering
          is applied to the result.
          """
          if not node_ids:
              return []
          stmt = (
              select(NodeRun)
              .where(NodeRun.run_id == run_id)
              .where(NodeRun.node_id.in_(node_ids))
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, node_run_id: str) -> NodeRun | None:
          """Return the node run with primary key ``node_run_id``, or ``None``."""
          return self.db.get(NodeRun, node_run_id)

      def get_next_queued_by_run(self, *, run_id: str) -> NodeRun | None:
          """Return the oldest due ``"queued"`` node run of ``run_id``.

          A row is due when ``scheduled_time`` is NULL or not in the future.
          This is a plain read: nothing is locked or modified.
          """
          stmt = (
              select(NodeRun)
              .where(NodeRun.run_id == run_id)
              .where(NodeRun.status == "queued")
              .where(
                  or_(
                      NodeRun.scheduled_time.is_(None),
                      NodeRun.scheduled_time <= datetime.utcnow(),
                  )
              )
              .order_by(NodeRun.created_time.asc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def claim_next_queued(
          self,
          *,
          worker_key: str,
          lease_expire_time: datetime,
      ) -> NodeRun | None:
          """Claim the next due queued node run for ``worker_key``.

          Candidates are ordered by the owning run's priority (highest first)
          and then by node-run creation time (oldest first). The chosen row is
          marked ``"running"``, tagged with ``worker_key`` and the lease
          expiry, and committed.

          Returns:
              The claimed, refreshed node run, or ``None`` when nothing is due
              (or every due row is currently locked by another transaction).
          """
          # One timestamp for both the eligibility cutoff and started_time.
          now = datetime.utcnow()
          stmt = (
              select(NodeRun)
              .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
              .where(NodeRun.status == "queued")
              .where(
                  or_(
                      NodeRun.scheduled_time.is_(None),
                      NodeRun.scheduled_time <= now,
                  )
              )
              .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
              # Lock only the NodeRun row.  Without OF, FOR UPDATE also locks
              # the joined WorkflowRun row, so with SKIP LOCKED concurrent
              # workers would skip every queued node of a run while one claim
              # on that run is in flight.
              .with_for_update(skip_locked=True, of=NodeRun)
              .limit(1)
          )
          entity = self.db.scalar(stmt)
          if entity is None:
              return None
          entity.status = "running"
          entity.worker_key = worker_key
          # Keep an existing start time if the row already has one.
          entity.started_time = entity.started_time or now
          entity.lease_expire_time = lease_expire_time
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
          """Put ``"running"`` node runs with an expired lease back in the queue.

          At most ``max_items`` rows are handled per call, earliest lease
          expiry first. Each released row is reset to ``"queued"``, its worker,
          lease, start and finish times are cleared, its scheduled and queued
          times are set to ``now_time``, and ``attempt_no`` is incremented.

          Returns:
              The number of node runs released.
          """
          stmt = (
              select(NodeRun)
              .where(NodeRun.status == "running")
              .where(NodeRun.lease_expire_time.is_not(None))
              .where(NodeRun.lease_expire_time <= now_time)
              .order_by(NodeRun.lease_expire_time.asc())
              .limit(max_items)
          )
          entities = list(self.db.scalars(stmt))
          for entity in entities:
              entity.status = "queued"
              entity.worker_key = None
              entity.lease_expire_time = None
              entity.scheduled_time = now_time
              entity.queued_time = now_time
              entity.started_time = None
              entity.finished_time = None
              entity.attempt_no += 1
          # Skip the commit entirely when nothing was changed.
          if entities:
              self.db.commit()
          return len(entities)

      def update_status(
          self,
          *,
          node_run_id: str,
          status: NodeRunStatus,
          worker_key: str | None = None,
          error_code: str | None = None,
          error_message: str | None = None,
          output_text: str | None = None,
          output_json: dict[str, JSONValue] | None = None,
      ) -> NodeRun | None:
          """Set the node run's status, worker, error and output fields.

          All of ``worker_key``, ``error_code``, ``error_message``,
          ``output_text`` and ``output_json`` are overwritten unconditionally,
          so omitted arguments reset the stored values to ``None``.

          Returns:
              The refreshed node run, or ``None`` when it does not exist.
          """
          entity = self.db.get(NodeRun, node_run_id)
          if entity is None:
              return None
          entity.status = status
          entity.worker_key = worker_key
          entity.error_code = error_code
          entity.error_message = error_message
          entity.output_text = output_text
          entity.output_json = output_json
          now = datetime.utcnow()
          # Keep the first start time if the node is marked running again.
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          # A lease is only kept while the node is running.
          if status != "running":
              entity.lease_expire_time = None
          # These three statuses stamp the finish time.
          if status in {"completed", "failed", "skipped"}:
              entity.finished_time = now
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def requeue_for_retry(
          self,
          *,
          node_run_id: str,
          scheduled_time: datetime,
          timeout_time: datetime | None,
          error_code: str | None,
          error_message: str | None,
          output_text: str | None,
          output_json: dict[str, JSONValue] | None,
      ) -> NodeRun | None:
          """Reset a node run to ``"queued"`` for another attempt.

          ``attempt_no`` is incremented, the worker, lease, start and finish
          times are cleared, ``queued_time`` is set to the current UTC time,
          and the supplied schedule, timeout, error and output values are
          stored as given.

          Returns:
              The refreshed node run, or ``None`` when it does not exist.
          """
          entity = self.db.get(NodeRun, node_run_id)
          if entity is None:
              return None
          entity.status = "queued"
          entity.attempt_no += 1
          entity.worker_key = None
          entity.lease_expire_time = None
          entity.scheduled_time = scheduled_time
          entity.timeout_time = timeout_time
          entity.queued_time = datetime.utcnow()
          entity.started_time = None
          entity.finished_time = None
          entity.error_code = error_code
          entity.error_message = error_message
          entity.output_text = output_text
          entity.output_json = output_json
          self.db.commit()
          self.db.refresh(entity)
          return entity
  class ExecutionLogRepository:
      """Writes and reads ``ExecutionLog`` rows through a SQLAlchemy session."""

      def __init__(self, db: Session) -> None:
          """Remember the session used for every statement."""
          self.db = db

      def create(
          self,
          *,
          run_id: str,
          node_run_id: str | None,
          event_type: str,
          level: str,
          message: str,
          detail_json: dict[str, JSONValue] | None,
      ) -> ExecutionLog:
          """Persist one log entry, commit, and return the refreshed row."""
          fields = {
              "run_id": run_id,
              "node_run_id": node_run_id,
              "event_type": event_type,
              "level": level,
              "message": message,
              "detail_json": detail_json,
          }
          log_entry = ExecutionLog(**fields)
          session = self.db
          session.add(log_entry)
          session.commit()
          session.refresh(log_entry)
          return log_entry

      def list_by_scope(
          self,
          *,
          run_id: str | None = None,
          node_run_id: str | None = None,
      ) -> list[ExecutionLog]:
          """Return log entries, oldest ``created_time`` first.

          Each argument that is not ``None`` narrows the result by equality
          on the column of the same name; with no arguments every entry is
          returned.
          """
          # Collect only the filters that were actually requested.
          conditions = []
          if run_id is not None:
              conditions.append(ExecutionLog.run_id == run_id)
          if node_run_id is not None:
              conditions.append(ExecutionLog.node_run_id == node_run_id)

          query = select(ExecutionLog)
          for condition in conditions:
              query = query.where(condition)
          query = query.order_by(ExecutionLog.created_time.asc())

          rows = self.db.scalars(query)
          return list(rows)
  class NodeArtifactRepository:
      """Writes and reads ``NodeArtifact`` rows through a SQLAlchemy session."""

      def __init__(self, db: Session) -> None:
          """Remember the session used for every statement."""
          self.db = db

      def create(
          self,
          *,
          run_id: str,
          node_run_id: str,
          node_id: str,
          artifact_type: str,
          name: str,
          mime_type: str | None,
          content_text: str | None,
          content_json: dict[str, JSONValue] | None,
          storage_uri: str | None = None,
          size_bytes: int | None = None,
      ) -> NodeArtifact:
          """Persist one artifact, commit, and return the refreshed row."""
          fields = {
              "run_id": run_id,
              "node_run_id": node_run_id,
              "node_id": node_id,
              "artifact_type": artifact_type,
              "name": name,
              "mime_type": mime_type,
              "content_text": content_text,
              "content_json": content_json,
              "storage_uri": storage_uri,
              "size_bytes": size_bytes,
          }
          artifact = NodeArtifact(**fields)
          session = self.db
          session.add(artifact)
          session.commit()
          session.refresh(artifact)
          return artifact

      def list_by_scope(
          self,
          *,
          run_id: str | None = None,
          node_run_id: str | None = None,
          artifact_type: str | None = None,
      ) -> list[NodeArtifact]:
          """Return artifacts, oldest ``created_time`` first.

          Each argument that is not ``None`` narrows the result by equality
          on the column of the same name; with no arguments every artifact is
          returned.
          """
          # Collect only the filters that were actually requested.
          conditions = []
          if run_id is not None:
              conditions.append(NodeArtifact.run_id == run_id)
          if node_run_id is not None:
              conditions.append(NodeArtifact.node_run_id == node_run_id)
          if artifact_type is not None:
              conditions.append(NodeArtifact.artifact_type == artifact_type)

          query = select(NodeArtifact)
          for condition in conditions:
              query = query.where(condition)
          query = query.order_by(NodeArtifact.created_time.asc())

          rows = self.db.scalars(query)
          return list(rows)
  class TraceSpanRepository:
      """Opens, closes and lists ``TraceSpan`` rows through a SQLAlchemy session."""

      def __init__(self, db: Session) -> None:
          """Remember the session used for every statement."""
          self.db = db

      def start(
          self,
          *,
          run_id: str,
          node_run_id: str | None,
          parent_span_id: str | None,
          span_type: str,
          name: str,
          attributes_json: dict[str, JSONValue] | None = None,
      ) -> TraceSpan:
          """Open a span with status ``"running"`` starting at the current UTC time.

          The span is committed immediately and returned refreshed.
          """
          fields = {
              "run_id": run_id,
              "node_run_id": node_run_id,
              "parent_span_id": parent_span_id,
              "span_type": span_type,
              "name": name,
              "status": "running",
              "started_time": datetime.utcnow(),
              "attributes_json": attributes_json,
          }
          span = TraceSpan(**fields)
          session = self.db
          session.add(span)
          session.commit()
          session.refresh(span)
          return span

      def finish(
          self,
          *,
          span_id: str,
          status: str,
          error_code: str | None = None,
          error_message: str | None = None,
          attributes_json: dict[str, JSONValue] | None = None,
      ) -> TraceSpan | None:
          """Close a span: record its end time, duration, status and error fields.

          ``duration_ms`` is the elapsed time since ``started_time`` in whole
          milliseconds (truncated). When ``attributes_json`` is given it is
          merged over the stored attributes, with the new keys winning;
          otherwise the stored attributes are left untouched.

          Returns:
              The refreshed span, or ``None`` when ``span_id`` does not exist.
          """
          span = self.db.get(TraceSpan, span_id)
          if span is None:
              return None

          finished_at = datetime.utcnow()
          elapsed = finished_at - span.started_time

          span.status = status
          span.ended_time = finished_at
          span.duration_ms = int(elapsed.total_seconds() * 1000)
          span.error_code = error_code
          span.error_message = error_message

          if attributes_json is not None:
              # Build a new dict rather than mutating the stored one in place.
              merged = dict(span.attributes_json or {})
              merged.update(attributes_json)
              span.attributes_json = merged

          self.db.commit()
          self.db.refresh(span)
          return span

      def list_by_scope(
          self,
          *,
          run_id: str | None = None,
          node_run_id: str | None = None,
          span_type: str | None = None,
      ) -> list[TraceSpan]:
          """Return spans, earliest ``started_time`` first.

          Each argument that is not ``None`` narrows the result by equality
          on the column of the same name; with no arguments every span is
          returned.
          """
          # Collect only the filters that were actually requested.
          conditions = []
          if run_id is not None:
              conditions.append(TraceSpan.run_id == run_id)
          if node_run_id is not None:
              conditions.append(TraceSpan.node_run_id == node_run_id)
          if span_type is not None:
              conditions.append(TraceSpan.span_type == span_type)

          query = select(TraceSpan)
          for condition in conditions:
              query = query.where(condition)
          query = query.order_by(TraceSpan.started_time.asc())

          rows = self.db.scalars(query)
          return list(rows)