| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import NodeRun, WorkflowRun
class WorkflowRunRepository:
    """Data-access helpers for ``WorkflowRun`` rows using one SQLAlchemy session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new workflow run with status ``"running"`` and commit it.

        ``started_time`` is set to the current naive UTC time. If the run ends
        up without a ``root_run_id``, it is treated as its own root and
        ``root_run_id`` is set to the new row's ``id``.

        Returns:
            The committed ``WorkflowRun``, refreshed from the database.
        """
        # Naive UTC timestamp: the same value datetime.utcnow() produced, but
        # without the API deprecated in Python 3.12.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity = WorkflowRun(
            tenant_id=tenant_id,
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=now,
        )
        self.db.add(entity)
        # flush() emits the INSERT so the primary key is populated, but keeps
        # the transaction open. The insert and the root_run_id update are
        # therefore committed together instead of in two separate commits.
        self.db.flush()
        if entity.root_run_id is None:
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
        """Return a tenant's workflow runs, newest ``created_time`` first.

        Args:
            tenant_id: Only runs belonging to this tenant are returned.
            session_id: When truthy, further restricts results to this
                session. ``None`` or an empty string applies no session filter.
        """
        stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
        if session_id:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on a run and commit.

        The lookup is by primary key only (no tenant check). If no run with
        ``run_id`` exists, this silently does nothing.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()
class NodeRunRepository:
    """Data-access helpers for ``NodeRun`` rows using one SQLAlchemy session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
    ) -> NodeRun:
        """Insert a node run with the given status and commit it.

        ``queued_time`` is set to the current naive UTC time.

        Returns:
            The committed ``NodeRun``, refreshed from the database.
        """
        # Naive UTC timestamp: the same value datetime.utcnow() produced, but
        # without the API deprecated in Python 3.12.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity = NodeRun(
            tenant_id=tenant_id,
            run_id=run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
        """Return the tenant's node runs for ``run_id``, oldest ``created_time`` first."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))
|