repositories.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import NodeRun, WorkflowRun
from core_domain import NodeRunStatus, WorkflowRunStatus
  6. class WorkflowRunRepository:
  7. def __init__(self, db: Session) -> None:
  8. self.db = db
  9. def create(
  10. self,
  11. *,
  12. tenant_id: str,
  13. app_id: str,
  14. app_version_id: str,
  15. workflow_id: str,
  16. workflow_version_id: str,
  17. session_id: str | None,
  18. parent_run_id: str | None,
  19. root_run_id: str | None,
  20. run_type: str,
  21. trigger_type: str,
  22. priority: int,
  23. ) -> WorkflowRun:
  24. now = datetime.utcnow()
  25. entity = WorkflowRun(
  26. tenant_id=tenant_id,
  27. app_id=app_id,
  28. app_version_id=app_version_id,
  29. workflow_id=workflow_id,
  30. workflow_version_id=workflow_version_id,
  31. session_id=session_id,
  32. parent_run_id=parent_run_id,
  33. root_run_id=root_run_id,
  34. run_type=run_type,
  35. trigger_type=trigger_type,
  36. priority=priority,
  37. status="running",
  38. started_time=now,
  39. )
  40. self.db.add(entity)
  41. self.db.commit()
  42. if entity.root_run_id is None:
  43. entity.root_run_id = entity.id
  44. self.db.commit()
  45. self.db.refresh(entity)
  46. return entity
  47. def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
  48. stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
  49. if session_id:
  50. stmt = stmt.where(WorkflowRun.session_id == session_id)
  51. stmt = stmt.order_by(WorkflowRun.created_time.desc())
  52. return list(self.db.scalars(stmt))
  53. def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
  54. entity = self.db.get(WorkflowRun, run_id)
  55. if entity is None:
  56. return
  57. entity.current_node_count = current_node_count
  58. self.db.commit()
  59. def get_by_id(self, run_id: str) -> WorkflowRun | None:
  60. return self.db.get(WorkflowRun, run_id)
  61. def update_status(
  62. self,
  63. *,
  64. run_id: str,
  65. status: WorkflowRunStatus,
  66. error_code: str | None = None,
  67. error_message: str | None = None,
  68. ) -> WorkflowRun | None:
  69. entity = self.db.get(WorkflowRun, run_id)
  70. if entity is None:
  71. return None
  72. entity.status = status
  73. entity.error_code = error_code
  74. entity.error_message = error_message
  75. now = datetime.utcnow()
  76. if status == "running" and entity.started_time is None:
  77. entity.started_time = now
  78. if status in {"completed", "failed", "cancelled"}:
  79. entity.finished_time = now
  80. self.db.commit()
  81. self.db.refresh(entity)
  82. return entity
  83. class NodeRunRepository:
  84. def __init__(self, db: Session) -> None:
  85. self.db = db
  86. def create(
  87. self,
  88. *,
  89. tenant_id: str,
  90. run_id: str,
  91. node_id: str,
  92. node_type: str,
  93. status: str,
  94. ) -> NodeRun:
  95. now = datetime.utcnow()
  96. entity = NodeRun(
  97. tenant_id=tenant_id,
  98. run_id=run_id,
  99. node_id=node_id,
  100. node_type=node_type,
  101. status=status,
  102. queued_time=now,
  103. )
  104. self.db.add(entity)
  105. self.db.commit()
  106. self.db.refresh(entity)
  107. return entity
  108. def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
  109. stmt = (
  110. select(NodeRun)
  111. .where(NodeRun.tenant_id == tenant_id)
  112. .where(NodeRun.run_id == run_id)
  113. .order_by(NodeRun.created_time.asc())
  114. )
  115. return list(self.db.scalars(stmt))
  116. def list_by_run_and_node_ids(
  117. self,
  118. *,
  119. tenant_id: str,
  120. run_id: str,
  121. node_ids: list[str],
  122. ) -> list[NodeRun]:
  123. if not node_ids:
  124. return []
  125. stmt = (
  126. select(NodeRun)
  127. .where(NodeRun.tenant_id == tenant_id)
  128. .where(NodeRun.run_id == run_id)
  129. .where(NodeRun.node_id.in_(node_ids))
  130. )
  131. return list(self.db.scalars(stmt))
  132. def get_by_id(self, node_run_id: str) -> NodeRun | None:
  133. return self.db.get(NodeRun, node_run_id)
  134. def update_status(
  135. self,
  136. *,
  137. node_run_id: str,
  138. status: NodeRunStatus,
  139. worker_key: str | None = None,
  140. error_code: str | None = None,
  141. error_message: str | None = None,
  142. ) -> NodeRun | None:
  143. entity = self.db.get(NodeRun, node_run_id)
  144. if entity is None:
  145. return None
  146. entity.status = status
  147. entity.worker_key = worker_key
  148. entity.error_code = error_code
  149. entity.error_message = error_message
  150. now = datetime.utcnow()
  151. if status == "running" and entity.started_time is None:
  152. entity.started_time = now
  153. if status in {"completed", "failed", "skipped"}:
  154. entity.finished_time = now
  155. self.db.commit()
  156. self.db.refresh(entity)
  157. return entity