"""SQLAlchemy repositories for workflow runs, node runs, logs, artifacts and spans.

Every mutating method in this module commits the ``Session`` it was given.
"""

from datetime import datetime, timezone

from core_domain import NodeRunStatus, WorkflowRunStatus
from core_shared import JSONValue
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Equivalent to the deprecated ``datetime.utcnow()`` (naive value holding UTC
    wall-clock time), so timestamps keep the same form this module always wrote.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class WorkflowRunRepository:
    """Persistence operations for ``WorkflowRun`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new run in ``"running"`` status and return the refreshed row.

        If ``root_run_id`` is ``None`` the run is treated as its own root and
        ``root_run_id`` is set to the run's generated id.
        """
        entity = WorkflowRun(
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=_utcnow(),
        )
        self.db.add(entity)
        if entity.root_run_id is None:
            # Flush (not commit) to obtain the generated primary key, so the
            # row and its self-referencing root id are committed together
            # instead of briefly persisting a run with no root.
            self.db.flush()
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, session_id: str | None = None) -> list[WorkflowRun]:
        """Return runs, newest first, optionally restricted to ``session_id``."""
        stmt = select(WorkflowRun)
        # ``is not None`` (like the other list_by_scope methods) so an empty
        # string is treated as a real filter value rather than "no filter".
        if session_id is not None:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on a run; silently no-op if it does not exist."""
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()

    def get_by_id(self, run_id: str) -> WorkflowRun | None:
        """Return the run with primary key ``run_id``, or ``None``."""
        return self.db.get(WorkflowRun, run_id)

    def update_status(
        self,
        *,
        run_id: str,
        status: WorkflowRunStatus,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> WorkflowRun | None:
        """Set a run's status and error fields; return the row or ``None`` if missing.

        ``error_code`` / ``error_message`` are always overwritten (``None``
        clears them). ``started_time`` is filled on a transition to
        ``"running"`` if unset; ``finished_time`` is stamped on
        ``"completed"``, ``"failed"`` or ``"cancelled"``.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return None
        entity.status = status
        entity.error_code = error_code
        entity.error_message = error_message
        now = _utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in {"completed", "failed", "cancelled"}:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity


class NodeRunRepository:
    """Persistence and work-queue operations for ``NodeRun`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _is_due(now: datetime):
        """SQL condition: the node has no schedule or its schedule is not in the future."""
        return or_(
            NodeRun.scheduled_time.is_(None),
            NodeRun.scheduled_time <= now,
        )

    def create(
        self,
        *,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
        scheduled_time: datetime | None = None,
        timeout_time: datetime | None = None,
        parent_node_run_id: str | None = None,
    ) -> NodeRun:
        """Insert a node run; ``scheduled_time`` falls back to "now" when not given."""
        now = _utcnow()
        entity = NodeRun(
            run_id=run_id,
            parent_node_run_id=parent_node_run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=now,
            scheduled_time=scheduled_time or now,
            timeout_time=timeout_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, run_id: str) -> list[NodeRun]:
        """Return all node runs of ``run_id`` in creation order."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def list_by_run_and_node_ids(
        self, *, run_id: str, node_ids: list[str]
    ) -> list[NodeRun]:
        """Return node runs of ``run_id`` whose ``node_id`` is in ``node_ids`` (unordered).

        An empty ``node_ids`` short-circuits to ``[]`` without querying.
        """
        if not node_ids:
            return []
        stmt = (
            select(NodeRun)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.node_id.in_(node_ids))
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, node_run_id: str) -> NodeRun | None:
        """Return the node run with primary key ``node_run_id``, or ``None``."""
        return self.db.get(NodeRun, node_run_id)

    def get_next_queued_by_run(self, *, run_id: str) -> NodeRun | None:
        """Return the oldest due ``"queued"`` node of ``run_id`` without claiming it."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.status == "queued")
            .where(self._is_due(_utcnow()))
            .order_by(NodeRun.created_time.asc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def claim_next_queued(
        self, *, worker_key: str, lease_expire_time: datetime
    ) -> NodeRun | None:
        """Atomically claim the next due queued node for ``worker_key``.

        Picks by highest ``WorkflowRun.priority`` then oldest node, marks it
        ``"running"`` with the given lease, commits and returns it; returns
        ``None`` when nothing is claimable.
        """
        now = _utcnow()
        stmt = (
            select(NodeRun)
            .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
            .where(NodeRun.status == "queued")
            .where(self._is_due(now))
            .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
            # Lock only the NodeRun row. Without ``of=`` the joined WorkflowRun
            # row is locked too, and with SKIP LOCKED other workers would then
            # skip every remaining queued node of the same run.
            .with_for_update(skip_locked=True, of=NodeRun)
            .limit(1)
        )
        entity = self.db.scalar(stmt)
        if entity is None:
            return None
        entity.status = "running"
        entity.worker_key = worker_key
        entity.started_time = entity.started_time or now
        entity.lease_expire_time = lease_expire_time
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
        """Requeue up to ``max_items`` running nodes whose lease expired by ``now_time``.

        Each released node has its worker/lease cleared, timing reset to
        ``now_time`` and ``attempt_no`` incremented. Returns the count released.
        """
        stmt = (
            select(NodeRun)
            .where(NodeRun.status == "running")
            .where(NodeRun.lease_expire_time.is_not(None))
            .where(NodeRun.lease_expire_time <= now_time)
            .order_by(NodeRun.lease_expire_time.asc())
            # Same locking pattern as claim_next_queued: concurrent sweepers
            # skip rows another sweeper holds instead of requeueing them twice
            # (which would double-increment attempt_no).
            .with_for_update(skip_locked=True)
            .limit(max_items)
        )
        entities = list(self.db.scalars(stmt))
        for entity in entities:
            entity.status = "queued"
            entity.worker_key = None
            entity.lease_expire_time = None
            entity.scheduled_time = now_time
            entity.queued_time = now_time
            entity.started_time = None
            entity.finished_time = None
            entity.attempt_no += 1
        if entities:
            self.db.commit()
        return len(entities)

    def update_status(
        self,
        *,
        node_run_id: str,
        status: NodeRunStatus,
        worker_key: str | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
    ) -> NodeRun | None:
        """Set a node run's status and result fields; return the row or ``None``.

        ``worker_key``, error and output fields are always overwritten with the
        passed values (``None`` clears them). Any status other than
        ``"running"`` drops the lease; ``"completed"``, ``"failed"`` and
        ``"skipped"`` stamp ``finished_time``.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = status
        entity.worker_key = worker_key
        entity.error_code = error_code
        entity.error_message = error_message
        entity.output_text = output_text
        entity.output_json = output_json
        now = _utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status != "running":
            entity.lease_expire_time = None
        if status in {"completed", "failed", "skipped"}:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def requeue_for_retry(
        self,
        *,
        node_run_id: str,
        scheduled_time: datetime,
        timeout_time: datetime | None,
        error_code: str | None,
        error_message: str | None,
        output_text: str | None,
        output_json: dict[str, JSONValue] | None,
    ) -> NodeRun | None:
        """Put a node back to ``"queued"`` for another attempt at ``scheduled_time``.

        Increments ``attempt_no``, clears worker/lease and start/finish times,
        and records the given error/output of the failed attempt. Returns the
        row or ``None`` if it does not exist.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = "queued"
        entity.attempt_no += 1
        entity.worker_key = None
        entity.lease_expire_time = None
        entity.scheduled_time = scheduled_time
        entity.timeout_time = timeout_time
        entity.queued_time = _utcnow()
        entity.started_time = None
        entity.finished_time = None
        entity.error_code = error_code
        entity.error_message = error_message
        entity.output_text = output_text
        entity.output_json = output_json
        self.db.commit()
        self.db.refresh(entity)
        return entity


class ExecutionLogRepository:
    """Persistence operations for ``ExecutionLog`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        run_id: str,
        node_run_id: str | None,
        event_type: str,
        level: str,
        message: str,
        detail_json: dict[str, JSONValue] | None,
    ) -> ExecutionLog:
        """Insert a log entry and return the refreshed row."""
        entity = ExecutionLog(
            run_id=run_id,
            node_run_id=node_run_id,
            event_type=event_type,
            level=level,
            message=message,
            detail_json=detail_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        run_id: str | None = None,
        node_run_id: str | None = None,
    ) -> list[ExecutionLog]:
        """Return log entries oldest first, filtered by any non-``None`` argument."""
        stmt = select(ExecutionLog)
        if run_id is not None:
            stmt = stmt.where(ExecutionLog.run_id == run_id)
        if node_run_id is not None:
            stmt = stmt.where(ExecutionLog.node_run_id == node_run_id)
        stmt = stmt.order_by(ExecutionLog.created_time.asc())
        return list(self.db.scalars(stmt))


class NodeArtifactRepository:
    """Persistence operations for ``NodeArtifact`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        run_id: str,
        node_run_id: str,
        node_id: str,
        artifact_type: str,
        name: str,
        mime_type: str | None,
        content_text: str | None,
        content_json: dict[str, JSONValue] | None,
        storage_uri: str | None = None,
        size_bytes: int | None = None,
    ) -> NodeArtifact:
        """Insert an artifact produced by a node run and return the refreshed row."""
        entity = NodeArtifact(
            run_id=run_id,
            node_run_id=node_run_id,
            node_id=node_id,
            artifact_type=artifact_type,
            name=name,
            mime_type=mime_type,
            content_text=content_text,
            content_json=content_json,
            storage_uri=storage_uri,
            size_bytes=size_bytes,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        run_id: str | None = None,
        node_run_id: str | None = None,
        artifact_type: str | None = None,
    ) -> list[NodeArtifact]:
        """Return artifacts oldest first, filtered by any non-``None`` argument."""
        stmt = select(NodeArtifact)
        if run_id is not None:
            stmt = stmt.where(NodeArtifact.run_id == run_id)
        if node_run_id is not None:
            stmt = stmt.where(NodeArtifact.node_run_id == node_run_id)
        if artifact_type is not None:
            stmt = stmt.where(NodeArtifact.artifact_type == artifact_type)
        stmt = stmt.order_by(NodeArtifact.created_time.asc())
        return list(self.db.scalars(stmt))


class TraceSpanRepository:
    """Persistence operations for ``TraceSpan`` rows."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def start(
        self,
        *,
        run_id: str,
        node_run_id: str | None,
        parent_span_id: str | None,
        span_type: str,
        name: str,
        attributes_json: dict[str, JSONValue] | None = None,
    ) -> TraceSpan:
        """Open a span in ``"running"`` status starting now and return it."""
        entity = TraceSpan(
            run_id=run_id,
            node_run_id=node_run_id,
            parent_span_id=parent_span_id,
            span_type=span_type,
            name=name,
            status="running",
            started_time=_utcnow(),
            attributes_json=attributes_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def finish(
        self,
        *,
        span_id: str,
        status: str,
        error_code: str | None = None,
        error_message: str | None = None,
        attributes_json: dict[str, JSONValue] | None = None,
    ) -> TraceSpan | None:
        """Close a span: set status, end time and ``duration_ms``; return it or ``None``.

        ``attributes_json`` is shallow-merged over the existing attributes
        (new keys win); passing ``None`` leaves them unchanged.
        """
        entity = self.db.get(TraceSpan, span_id)
        if entity is None:
            return None
        ended_time = _utcnow()
        entity.status = status
        entity.ended_time = ended_time
        entity.duration_ms = int((ended_time - entity.started_time).total_seconds() * 1000)
        entity.error_code = error_code
        entity.error_message = error_message
        if attributes_json is not None:
            # Build a new dict rather than mutating in place so the ORM sees a
            # changed value on the JSON column.
            entity.attributes_json = {
                **(entity.attributes_json or {}),
                **attributes_json,
            }
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        run_id: str | None = None,
        node_run_id: str | None = None,
        span_type: str | None = None,
    ) -> list[TraceSpan]:
        """Return spans ordered by start time, filtered by any non-``None`` argument."""
        stmt = select(TraceSpan)
        if run_id is not None:
            stmt = stmt.where(TraceSpan.run_id == run_id)
        if node_run_id is not None:
            stmt = stmt.where(TraceSpan.node_run_id == node_run_id)
        if span_type is not None:
            stmt = stmt.where(TraceSpan.span_type == span_type)
        stmt = stmt.order_by(TraceSpan.started_time.asc())
        return list(self.db.scalars(stmt))