# repositories.py
  1. from datetime import datetime
  2. from core_domain import NodeRunStatus, WorkflowRunStatus
  3. from core_shared import JSONValue
  4. from sqlalchemy import or_, select
  5. from sqlalchemy.orm import Session
  6. from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
  7. class WorkflowRunRepository:
  8. def __init__(self, db: Session) -> None:
  9. self.db = db
  10. def create(
  11. self,
  12. *,
  13. app_id: str,
  14. app_version_id: str,
  15. workflow_id: str,
  16. workflow_version_id: str,
  17. session_id: str | None,
  18. parent_run_id: str | None,
  19. root_run_id: str | None,
  20. run_type: str,
  21. trigger_type: str,
  22. priority: int) -> WorkflowRun:
  23. now = datetime.utcnow()
  24. entity = WorkflowRun(
  25. app_id=app_id,
  26. app_version_id=app_version_id,
  27. workflow_id=workflow_id,
  28. workflow_version_id=workflow_version_id,
  29. session_id=session_id,
  30. parent_run_id=parent_run_id,
  31. root_run_id=root_run_id,
  32. run_type=run_type,
  33. trigger_type=trigger_type,
  34. priority=priority,
  35. status="running",
  36. started_time=now)
  37. self.db.add(entity)
  38. self.db.commit()
  39. if entity.root_run_id is None:
  40. entity.root_run_id = entity.id
  41. self.db.commit()
  42. self.db.refresh(entity)
  43. return entity
  44. def list_by_scope(self, *, session_id: str | None = None) -> list[WorkflowRun]:
  45. stmt = select(WorkflowRun)
  46. if session_id:
  47. stmt = stmt.where(WorkflowRun.session_id == session_id)
  48. stmt = stmt.order_by(WorkflowRun.created_time.desc())
  49. return list(self.db.scalars(stmt))
  50. def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
  51. entity = self.db.get(WorkflowRun, run_id)
  52. if entity is None:
  53. return
  54. entity.current_node_count = current_node_count
  55. self.db.commit()
  56. def get_by_id(self, run_id: str) -> WorkflowRun | None:
  57. return self.db.get(WorkflowRun, run_id)
  58. def update_status(
  59. self,
  60. *,
  61. run_id: str,
  62. status: WorkflowRunStatus,
  63. error_code: str | None = None,
  64. error_message: str | None = None) -> WorkflowRun | None:
  65. entity = self.db.get(WorkflowRun, run_id)
  66. if entity is None:
  67. return None
  68. entity.status = status
  69. entity.error_code = error_code
  70. entity.error_message = error_message
  71. now = datetime.utcnow()
  72. if status == "running" and entity.started_time is None:
  73. entity.started_time = now
  74. if status in {"completed", "failed", "cancelled"}:
  75. entity.finished_time = now
  76. self.db.commit()
  77. self.db.refresh(entity)
  78. return entity
  79. class NodeRunRepository:
  80. def __init__(self, db: Session) -> None:
  81. self.db = db
  82. def create(
  83. self,
  84. *,
  85. run_id: str,
  86. node_id: str,
  87. node_type: str,
  88. status: str,
  89. scheduled_time: datetime | None = None,
  90. timeout_time: datetime | None = None,
  91. parent_node_run_id: str | None = None) -> NodeRun:
  92. now = datetime.utcnow()
  93. entity = NodeRun(
  94. run_id=run_id,
  95. parent_node_run_id=parent_node_run_id,
  96. node_id=node_id,
  97. node_type=node_type,
  98. status=status,
  99. queued_time=now,
  100. scheduled_time=scheduled_time or now,
  101. timeout_time=timeout_time)
  102. self.db.add(entity)
  103. self.db.commit()
  104. self.db.refresh(entity)
  105. return entity
  106. def list_by_run(self, *, run_id: str) -> list[NodeRun]:
  107. stmt = (
  108. select(NodeRun)
  109. .where(NodeRun.run_id == run_id)
  110. .order_by(NodeRun.created_time.asc())
  111. )
  112. return list(self.db.scalars(stmt))
  113. def list_by_run_and_node_ids(
  114. self,
  115. *,
  116. run_id: str,
  117. node_ids: list[str]) -> list[NodeRun]:
  118. if not node_ids:
  119. return []
  120. stmt = (
  121. select(NodeRun)
  122. .where(NodeRun.run_id == run_id)
  123. .where(NodeRun.node_id.in_(node_ids))
  124. )
  125. return list(self.db.scalars(stmt))
  126. def get_by_id(self, node_run_id: str) -> NodeRun | None:
  127. return self.db.get(NodeRun, node_run_id)
  128. def get_next_queued_by_run(self, *, run_id: str) -> NodeRun | None:
  129. stmt = (
  130. select(NodeRun)
  131. .where(NodeRun.run_id == run_id)
  132. .where(NodeRun.status == "queued")
  133. .where(
  134. or_(
  135. NodeRun.scheduled_time.is_(None),
  136. NodeRun.scheduled_time <= datetime.utcnow())
  137. )
  138. .order_by(NodeRun.created_time.asc())
  139. .limit(1)
  140. )
  141. return self.db.scalar(stmt)
  142. def claim_next_queued(
  143. self,
  144. *,
  145. worker_key: str,
  146. lease_expire_time: datetime) -> NodeRun | None:
  147. stmt = (
  148. select(NodeRun)
  149. .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
  150. .where(NodeRun.status == "queued")
  151. .where(
  152. or_(
  153. NodeRun.scheduled_time.is_(None),
  154. NodeRun.scheduled_time <= datetime.utcnow())
  155. )
  156. .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
  157. .with_for_update(skip_locked=True)
  158. .limit(1)
  159. )
  160. entity = self.db.scalar(stmt)
  161. if entity is None:
  162. return None
  163. now = datetime.utcnow()
  164. entity.status = "running"
  165. entity.worker_key = worker_key
  166. entity.started_time = entity.started_time or now
  167. entity.lease_expire_time = lease_expire_time
  168. self.db.commit()
  169. self.db.refresh(entity)
  170. return entity
  171. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  172. stmt = (
  173. select(NodeRun)
  174. .where(NodeRun.status == "running")
  175. .where(NodeRun.lease_expire_time.is_not(None))
  176. .where(NodeRun.lease_expire_time <= now_time)
  177. .order_by(NodeRun.lease_expire_time.asc())
  178. .limit(max_items)
  179. )
  180. entities = list(self.db.scalars(stmt))
  181. for entity in entities:
  182. entity.status = "queued"
  183. entity.worker_key = None
  184. entity.lease_expire_time = None
  185. entity.scheduled_time = now_time
  186. entity.queued_time = now_time
  187. entity.started_time = None
  188. entity.finished_time = None
  189. entity.attempt_no += 1
  190. if entities:
  191. self.db.commit()
  192. return len(entities)
  193. def update_status(
  194. self,
  195. *,
  196. node_run_id: str,
  197. status: NodeRunStatus,
  198. worker_key: str | None = None,
  199. error_code: str | None = None,
  200. error_message: str | None = None,
  201. output_text: str | None = None,
  202. output_json: dict[str, JSONValue] | None = None) -> NodeRun | None:
  203. entity = self.db.get(NodeRun, node_run_id)
  204. if entity is None:
  205. return None
  206. entity.status = status
  207. entity.worker_key = worker_key
  208. entity.error_code = error_code
  209. entity.error_message = error_message
  210. entity.output_text = output_text
  211. entity.output_json = output_json
  212. now = datetime.utcnow()
  213. if status == "running" and entity.started_time is None:
  214. entity.started_time = now
  215. if status != "running":
  216. entity.lease_expire_time = None
  217. if status in {"completed", "failed", "skipped"}:
  218. entity.finished_time = now
  219. self.db.commit()
  220. self.db.refresh(entity)
  221. return entity
  222. def requeue_for_retry(
  223. self,
  224. *,
  225. node_run_id: str,
  226. scheduled_time: datetime,
  227. timeout_time: datetime | None,
  228. error_code: str | None,
  229. error_message: str | None,
  230. output_text: str | None,
  231. output_json: dict[str, JSONValue] | None) -> NodeRun | None:
  232. entity = self.db.get(NodeRun, node_run_id)
  233. if entity is None:
  234. return None
  235. entity.status = "queued"
  236. entity.attempt_no += 1
  237. entity.worker_key = None
  238. entity.lease_expire_time = None
  239. entity.scheduled_time = scheduled_time
  240. entity.timeout_time = timeout_time
  241. entity.queued_time = datetime.utcnow()
  242. entity.started_time = None
  243. entity.finished_time = None
  244. entity.error_code = error_code
  245. entity.error_message = error_message
  246. entity.output_text = output_text
  247. entity.output_json = output_json
  248. self.db.commit()
  249. self.db.refresh(entity)
  250. return entity
  251. class ExecutionLogRepository:
  252. def __init__(self, db: Session) -> None:
  253. self.db = db
  254. def create(
  255. self,
  256. *,
  257. run_id: str,
  258. node_run_id: str | None,
  259. event_type: str,
  260. level: str,
  261. message: str,
  262. detail_json: dict[str, JSONValue] | None) -> ExecutionLog:
  263. entity = ExecutionLog(
  264. run_id=run_id,
  265. node_run_id=node_run_id,
  266. event_type=event_type,
  267. level=level,
  268. message=message,
  269. detail_json=detail_json)
  270. self.db.add(entity)
  271. self.db.commit()
  272. self.db.refresh(entity)
  273. return entity
  274. def list_by_scope(
  275. self,
  276. *,
  277. run_id: str | None = None,
  278. node_run_id: str | None = None) -> list[ExecutionLog]:
  279. stmt = select(ExecutionLog)
  280. if run_id is not None:
  281. stmt = stmt.where(ExecutionLog.run_id == run_id)
  282. if node_run_id is not None:
  283. stmt = stmt.where(ExecutionLog.node_run_id == node_run_id)
  284. stmt = stmt.order_by(ExecutionLog.created_time.asc())
  285. return list(self.db.scalars(stmt))
  286. class NodeArtifactRepository:
  287. def __init__(self, db: Session) -> None:
  288. self.db = db
  289. def create(
  290. self,
  291. *,
  292. run_id: str,
  293. node_run_id: str,
  294. node_id: str,
  295. artifact_type: str,
  296. name: str,
  297. mime_type: str | None,
  298. content_text: str | None,
  299. content_json: dict[str, JSONValue] | None,
  300. storage_uri: str | None = None,
  301. size_bytes: int | None = None) -> NodeArtifact:
  302. entity = NodeArtifact(
  303. run_id=run_id,
  304. node_run_id=node_run_id,
  305. node_id=node_id,
  306. artifact_type=artifact_type,
  307. name=name,
  308. mime_type=mime_type,
  309. content_text=content_text,
  310. content_json=content_json,
  311. storage_uri=storage_uri,
  312. size_bytes=size_bytes)
  313. self.db.add(entity)
  314. self.db.commit()
  315. self.db.refresh(entity)
  316. return entity
  317. def list_by_scope(
  318. self,
  319. *,
  320. run_id: str | None = None,
  321. node_run_id: str | None = None,
  322. artifact_type: str | None = None) -> list[NodeArtifact]:
  323. stmt = select(NodeArtifact)
  324. if run_id is not None:
  325. stmt = stmt.where(NodeArtifact.run_id == run_id)
  326. if node_run_id is not None:
  327. stmt = stmt.where(NodeArtifact.node_run_id == node_run_id)
  328. if artifact_type is not None:
  329. stmt = stmt.where(NodeArtifact.artifact_type == artifact_type)
  330. stmt = stmt.order_by(NodeArtifact.created_time.asc())
  331. return list(self.db.scalars(stmt))
  332. class TraceSpanRepository:
  333. def __init__(self, db: Session) -> None:
  334. self.db = db
  335. def start(
  336. self,
  337. *,
  338. run_id: str,
  339. node_run_id: str | None,
  340. parent_span_id: str | None,
  341. span_type: str,
  342. name: str,
  343. attributes_json: dict[str, JSONValue] | None = None) -> TraceSpan:
  344. entity = TraceSpan(
  345. run_id=run_id,
  346. node_run_id=node_run_id,
  347. parent_span_id=parent_span_id,
  348. span_type=span_type,
  349. name=name,
  350. status="running",
  351. started_time=datetime.utcnow(),
  352. attributes_json=attributes_json)
  353. self.db.add(entity)
  354. self.db.commit()
  355. self.db.refresh(entity)
  356. return entity
  357. def finish(
  358. self,
  359. *,
  360. span_id: str,
  361. status: str,
  362. error_code: str | None = None,
  363. error_message: str | None = None,
  364. attributes_json: dict[str, JSONValue] | None = None) -> TraceSpan | None:
  365. entity = self.db.get(TraceSpan, span_id)
  366. if entity is None:
  367. return None
  368. ended_time = datetime.utcnow()
  369. entity.status = status
  370. entity.ended_time = ended_time
  371. entity.duration_ms = int((ended_time - entity.started_time).total_seconds() * 1000)
  372. entity.error_code = error_code
  373. entity.error_message = error_message
  374. if attributes_json is not None:
  375. entity.attributes_json = {
  376. **(entity.attributes_json or {}),
  377. **attributes_json,
  378. }
  379. self.db.commit()
  380. self.db.refresh(entity)
  381. return entity
  382. def list_by_scope(
  383. self,
  384. *,
  385. run_id: str | None = None,
  386. node_run_id: str | None = None,
  387. span_type: str | None = None) -> list[TraceSpan]:
  388. stmt = select(TraceSpan)
  389. if run_id is not None:
  390. stmt = stmt.where(TraceSpan.run_id == run_id)
  391. if node_run_id is not None:
  392. stmt = stmt.where(TraceSpan.node_run_id == node_run_id)
  393. if span_type is not None:
  394. stmt = stmt.where(TraceSpan.span_type == span_type)
  395. stmt = stmt.order_by(TraceSpan.started_time.asc())
  396. return list(self.db.scalars(stmt))