repositories.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from datetime import datetime
  2. from sqlalchemy import select
  3. from sqlalchemy.orm import Session
  4. from app.db.models import NodeRun, WorkflowRun
  5. class WorkflowRunRepository:
  6. def __init__(self, db: Session) -> None:
  7. self.db = db
  8. def create(
  9. self,
  10. *,
  11. tenant_id: str,
  12. app_id: str,
  13. app_version_id: str,
  14. workflow_id: str,
  15. workflow_version_id: str,
  16. session_id: str | None,
  17. parent_run_id: str | None,
  18. root_run_id: str | None,
  19. run_type: str,
  20. trigger_type: str,
  21. priority: int,
  22. ) -> WorkflowRun:
  23. now = datetime.utcnow()
  24. entity = WorkflowRun(
  25. tenant_id=tenant_id,
  26. app_id=app_id,
  27. app_version_id=app_version_id,
  28. workflow_id=workflow_id,
  29. workflow_version_id=workflow_version_id,
  30. session_id=session_id,
  31. parent_run_id=parent_run_id,
  32. root_run_id=root_run_id,
  33. run_type=run_type,
  34. trigger_type=trigger_type,
  35. priority=priority,
  36. status="running",
  37. started_time=now,
  38. )
  39. self.db.add(entity)
  40. self.db.commit()
  41. if entity.root_run_id is None:
  42. entity.root_run_id = entity.id
  43. self.db.commit()
  44. self.db.refresh(entity)
  45. return entity
  46. def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
  47. stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
  48. if session_id:
  49. stmt = stmt.where(WorkflowRun.session_id == session_id)
  50. stmt = stmt.order_by(WorkflowRun.created_time.desc())
  51. return list(self.db.scalars(stmt))
  52. def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
  53. entity = self.db.get(WorkflowRun, run_id)
  54. if entity is None:
  55. return
  56. entity.current_node_count = current_node_count
  57. self.db.commit()
  58. class NodeRunRepository:
  59. def __init__(self, db: Session) -> None:
  60. self.db = db
  61. def create(
  62. self,
  63. *,
  64. tenant_id: str,
  65. run_id: str,
  66. node_id: str,
  67. node_type: str,
  68. status: str,
  69. ) -> NodeRun:
  70. now = datetime.utcnow()
  71. entity = NodeRun(
  72. tenant_id=tenant_id,
  73. run_id=run_id,
  74. node_id=node_id,
  75. node_type=node_type,
  76. status=status,
  77. queued_time=now,
  78. )
  79. self.db.add(entity)
  80. self.db.commit()
  81. self.db.refresh(entity)
  82. return entity
  83. def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
  84. stmt = (
  85. select(NodeRun)
  86. .where(NodeRun.tenant_id == tenant_id)
  87. .where(NodeRun.run_id == run_id)
  88. .order_by(NodeRun.created_time.asc())
  89. )
  90. return list(self.db.scalars(stmt))