from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import NodeRun, WorkflowRun


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Produces the same value ``datetime.utcnow()`` did (naive, UTC wall clock)
    without using that method, which is deprecated since Python 3.12. The
    result is kept naive so it matches whatever the existing timestamp
    columns already store.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class WorkflowRunRepository:
    """Data-access helpers for ``WorkflowRun`` rows.

    Every write method commits the bound session. Note that committing also
    commits any other pending changes already held by that session.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new run with status ``"running"`` and return it.

        ``started_time`` is set to the current naive-UTC time. If
        ``root_run_id`` is ``None``, the new run becomes its own root
        (``root_run_id`` is set to the run's own ``id``).

        The session is committed once and the entity is refreshed before
        being returned.
        """
        entity = WorkflowRun(
            tenant_id=tenant_id,
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=_utcnow(),
        )
        self.db.add(entity)
        # Flush (rather than commit) so the primary key is assigned inside
        # the current transaction. A top-level run can then point
        # root_run_id at itself, and the row is never committed with a NULL
        # root_run_id.
        self.db.flush()
        if entity.root_run_id is None:
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
        """Return the tenant's runs, newest ``created_time`` first.

        When ``session_id`` is truthy, only runs from that session are
        returned. ``None`` or an empty string disables the session filter.
        """
        stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
        if session_id:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on run ``run_id`` and commit.

        This is a no-op (nothing is committed) if no run with that id exists.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()


class NodeRunRepository:
    """Data-access helpers for ``NodeRun`` rows.

    Write methods commit the bound session.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
    ) -> NodeRun:
        """Insert a node run with the given status and return it.

        ``queued_time`` is set to the current naive-UTC time. The session is
        committed and the entity is refreshed before being returned.
        """
        entity = NodeRun(
            tenant_id=tenant_id,
            run_id=run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=_utcnow(),
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
        """Return the node runs of ``run_id`` within ``tenant_id``, oldest ``created_time`` first."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))