| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import NodeRun, WorkflowRun
from core_domain import NodeRunStatus, WorkflowRunStatus
class WorkflowRunRepository:
    """Persistence helpers for ``WorkflowRun`` rows, bound to one SQLAlchemy session.

    Every mutating method commits the session before returning.
    """

    # Statuses for which ``update_status`` stamps ``finished_time``.
    _FINISHED_STATUSES = frozenset({"completed", "failed", "cancelled"})

    def __init__(self, db: Session) -> None:
        """Bind the repository to ``db``; the caller owns the session's lifetime."""
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Yields the same value ``datetime.utcnow()`` produced (naive, UTC wall
        clock), so stored timestamps keep their format, without the
        deprecation warning ``utcnow()`` emits on Python 3.12+.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new run in the ``"running"`` state and return the refreshed row.

        ``started_time`` is set to the current UTC time. When ``root_run_id``
        is None, the run is treated as its own root: ``root_run_id`` is
        back-filled with the newly assigned primary key. The insert and the
        back-fill are committed together in a single transaction.
        """
        entity = WorkflowRun(
            tenant_id=tenant_id,
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=self._utcnow(),
        )
        self.db.add(entity)
        # Flush (not commit) so the primary key is assigned while the INSERT
        # and the root_run_id back-fill remain in one transaction. A failure
        # can no longer leave a committed row whose root_run_id is NULL.
        self.db.flush()
        if entity.root_run_id is None:
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
        """Return the tenant's runs, newest ``created_time`` first.

        A truthy ``session_id`` narrows the result to that session. ``None``
        or ``""`` means no session filter.
        """
        stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
        if session_id:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on the run and commit.

        Does nothing if no run with ``run_id`` exists.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()

    def get_by_id(self, run_id: str) -> WorkflowRun | None:
        """Return the run with primary key ``run_id``, or None if absent."""
        return self.db.get(WorkflowRun, run_id)

    def update_status(
        self,
        *,
        run_id: str,
        status: WorkflowRunStatus,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> WorkflowRun | None:
        """Set the run's status, stamp timestamps, commit, and return the refreshed row.

        ``error_code`` and ``error_message`` are always written, so omitting
        them clears any previously stored error. ``started_time`` is set only
        the first time the status becomes ``"running"``. ``finished_time`` is
        (re)set whenever the status is completed/failed/cancelled.

        Returns None, without committing, if no run with ``run_id`` exists.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return None
        entity.status = status
        entity.error_code = error_code
        entity.error_message = error_message
        now = self._utcnow()
        # NOTE(review): these comparisons assume WorkflowRunStatus values
        # compare equal to plain strings (str-based enum or Literal) — confirm.
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in self._FINISHED_STATUSES:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity
class NodeRunRepository:
    """Persistence helpers for ``NodeRun`` rows, bound to one SQLAlchemy session.

    Every mutating method commits the session before returning.
    """

    # Statuses for which ``update_status`` stamps ``finished_time``.
    # NOTE(review): WorkflowRunRepository also treats "cancelled" this way;
    # if NodeRunStatus defines "cancelled", confirm whether it belongs here.
    _FINISHED_STATUSES = frozenset({"completed", "failed", "skipped"})

    def __init__(self, db: Session) -> None:
        """Bind the repository to ``db``; the caller owns the session's lifetime."""
        self.db = db

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Yields the same value ``datetime.utcnow()`` produced (naive, UTC wall
        clock) without the deprecation warning it emits on Python 3.12+.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
    ) -> NodeRun:
        """Insert a node run with the given status, commit, and return the refreshed row.

        ``queued_time`` is set to the current UTC time.
        """
        entity = NodeRun(
            tenant_id=tenant_id,
            run_id=run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=self._utcnow(),
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
        """Return all node runs of ``run_id`` for the tenant, oldest ``created_time`` first."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def list_by_run_and_node_ids(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_ids: list[str],
    ) -> list[NodeRun]:
        """Return node runs of ``run_id`` whose ``node_id`` is in ``node_ids``.

        Results are ordered oldest ``created_time`` first, matching
        ``list_by_run``, so callers get a deterministic order. An empty
        ``node_ids`` returns ``[]`` without querying the database.
        """
        if not node_ids:
            return []
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.node_id.in_(node_ids))
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, node_run_id: str) -> NodeRun | None:
        """Return the node run with primary key ``node_run_id``, or None if absent."""
        return self.db.get(NodeRun, node_run_id)

    def update_status(
        self,
        *,
        node_run_id: str,
        status: NodeRunStatus,
        worker_key: str | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> NodeRun | None:
        """Set the node run's status, stamp timestamps, commit, and return the refreshed row.

        ``worker_key``, ``error_code`` and ``error_message`` are always
        written, so omitting any of them resets the stored value to None.
        ``started_time`` is set only the first time the status becomes
        ``"running"``. ``finished_time`` is (re)set whenever the status is
        completed/failed/skipped.

        Returns None, without committing, if no node run with ``node_run_id``
        exists.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = status
        entity.worker_key = worker_key
        entity.error_code = error_code
        entity.error_message = error_message
        now = self._utcnow()
        # NOTE(review): assumes NodeRunStatus values compare equal to plain
        # strings (str-based enum or Literal) — confirm.
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in self._FINISHED_STATUSES:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity
|