from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import NodeRun, WorkflowRun
from core_domain import NodeRunStatus, WorkflowRunStatus

# Status values that cause ``update_status`` to stamp ``finished_time``.
_WORKFLOW_RUN_FINISHED_STATUSES = frozenset({"completed", "failed", "cancelled"})
_NODE_RUN_FINISHED_STATUSES = frozenset({"completed", "failed", "skipped"})


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as ``datetime.utcnow()`` (deprecated since
    Python 3.12) and stays timezone-naive, so stored timestamps keep the
    representation they had before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class WorkflowRunRepository:
    """Data-access helpers for ``WorkflowRun`` rows.

    Every mutating method commits the injected session before returning.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new workflow run with status ``"running"`` and return it.

        ``started_time`` is set to the current UTC time. When ``root_run_id``
        is ``None`` the new run becomes its own root: ``root_run_id`` is set
        to the new row's id. The insert and that backfill are committed in a
        single transaction, and the returned entity is refreshed.
        """
        now = _utcnow()
        entity = WorkflowRun(
            tenant_id=tenant_id,
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=now,
        )
        self.db.add(entity)
        # Flush rather than commit so the primary key is assigned while the
        # root-id backfill stays in the same transaction as the INSERT.
        # Two separate commits could leave a committed row without a
        # root_run_id if the second commit failed.
        self.db.flush()
        if entity.root_run_id is None:
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
        """Return the tenant's workflow runs, newest ``created_time`` first.

        The session filter is applied only when ``session_id`` is truthy;
        ``None`` or an empty string returns runs from every session.
        """
        stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
        if session_id:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on the run and commit.

        Silently does nothing when no run with ``run_id`` exists.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()

    def get_by_id(self, run_id: str) -> WorkflowRun | None:
        """Return the run whose primary key is ``run_id``, or ``None``."""
        return self.db.get(WorkflowRun, run_id)

    def update_status(
        self,
        *,
        run_id: str,
        status: WorkflowRunStatus,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> WorkflowRun | None:
        """Update a run's status and error fields, stamping lifecycle times.

        ``error_code`` and ``error_message`` are always overwritten, so
        omitting them clears any previous values. ``started_time`` is
        stamped only the first time the status becomes ``"running"``.
        ``finished_time`` is stamped on every update to ``"completed"``,
        ``"failed"`` or ``"cancelled"``.

        Returns:
            The refreshed run, or ``None`` if ``run_id`` does not exist.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return None
        entity.status = status
        entity.error_code = error_code
        entity.error_message = error_message
        now = _utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in _WORKFLOW_RUN_FINISHED_STATUSES:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity


class NodeRunRepository:
    """Data-access helpers for ``NodeRun`` rows (one row per node execution).

    Every mutating method commits the injected session before returning.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
    ) -> NodeRun:
        """Insert a node run with the given status, stamp ``queued_time``, and return it refreshed."""
        now = _utcnow()
        entity = NodeRun(
            tenant_id=tenant_id,
            run_id=run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
        """Return every node run of ``run_id`` for the tenant, oldest ``created_time`` first."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def list_by_run_and_node_ids(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_ids: list[str],
    ) -> list[NodeRun]:
        """Return node runs of ``run_id`` whose ``node_id`` is in ``node_ids``.

        An empty ``node_ids`` short-circuits to ``[]`` without querying.
        No ordering is applied to the result.
        """
        if not node_ids:
            return []
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.node_id.in_(node_ids))
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, node_run_id: str) -> NodeRun | None:
        """Return the node run whose primary key is ``node_run_id``, or ``None``."""
        return self.db.get(NodeRun, node_run_id)

    def update_status(
        self,
        *,
        node_run_id: str,
        status: NodeRunStatus,
        worker_key: str | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> NodeRun | None:
        """Update a node run's status, worker and error fields, stamping lifecycle times.

        ``worker_key``, ``error_code`` and ``error_message`` are always
        overwritten, so omitting them sets them to ``None``.
        ``started_time`` is stamped only the first time the status becomes
        ``"running"``. ``finished_time`` is stamped on every update to
        ``"completed"``, ``"failed"`` or ``"skipped"``.

        Returns:
            The refreshed node run, or ``None`` if ``node_run_id`` does not exist.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = status
        # NOTE(review): a terminal update that omits worker_key clears the
        # worker recorded while running — confirm callers always pass it.
        entity.worker_key = worker_key
        entity.error_code = error_code
        entity.error_message = error_message
        now = _utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in _NODE_RUN_FINISHED_STATUSES:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity