| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479 |
- from datetime import datetime
- from sqlalchemy import or_, select
- from sqlalchemy.orm import Session
- from app.db.models import ExecutionLog, NodeArtifact, NodeRun, TraceSpan, WorkflowRun
- from core_domain import NodeRunStatus, WorkflowRunStatus
- from core_shared import JSONValue
class WorkflowRunRepository:
    """Persistence helpers for ``WorkflowRun`` rows.

    Every mutating method commits the session before returning.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        app_id: str,
        app_version_id: str,
        workflow_id: str,
        workflow_version_id: str,
        session_id: str | None,
        parent_run_id: str | None,
        root_run_id: str | None,
        run_type: str,
        trigger_type: str,
        priority: int,
    ) -> WorkflowRun:
        """Insert a new workflow run in the ``"running"`` state and return it.

        ``started_time`` is set to the current (naive) UTC time. When
        ``root_run_id`` is None the run is treated as its own root, and
        ``root_run_id`` is back-filled with the newly assigned primary key.
        The insert and the back-fill are committed in a single transaction.
        """
        now = datetime.utcnow()
        entity = WorkflowRun(
            tenant_id=tenant_id,
            app_id=app_id,
            app_version_id=app_version_id,
            workflow_id=workflow_id,
            workflow_version_id=workflow_version_id,
            session_id=session_id,
            parent_run_id=parent_run_id,
            root_run_id=root_run_id,
            run_type=run_type,
            trigger_type=trigger_type,
            priority=priority,
            status="running",
            started_time=now,
        )
        self.db.add(entity)
        # Flush (rather than commit) to obtain the primary key inside the same
        # transaction. Committing twice could leave a persisted root run with a
        # NULL root_run_id if the second commit failed.
        self.db.flush()
        if entity.root_run_id is None:
            entity.root_run_id = entity.id
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(self, *, tenant_id: str, session_id: str | None = None) -> list[WorkflowRun]:
        """Return the tenant's workflow runs, newest ``created_time`` first.

        A falsy ``session_id`` (None or an empty string) disables the
        session filter.
        """
        stmt = select(WorkflowRun).where(WorkflowRun.tenant_id == tenant_id)
        if session_id:
            stmt = stmt.where(WorkflowRun.session_id == session_id)
        stmt = stmt.order_by(WorkflowRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def update_node_count(self, *, run_id: str, current_node_count: int) -> None:
        """Set ``current_node_count`` on the run; silently no-op if it does not exist."""
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return
        entity.current_node_count = current_node_count
        self.db.commit()

    def get_by_id(self, run_id: str) -> WorkflowRun | None:
        """Return the run with primary key ``run_id``, or None."""
        return self.db.get(WorkflowRun, run_id)

    def update_status(
        self,
        *,
        run_id: str,
        status: WorkflowRunStatus,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> WorkflowRun | None:
        """Change a run's status and return the refreshed entity (None if missing).

        ``error_code``/``error_message`` are always overwritten, so omitting
        them clears any previous error. Entering ``"running"`` stamps
        ``started_time`` only if it was unset; entering ``"completed"``,
        ``"failed"`` or ``"cancelled"`` stamps ``finished_time``.
        """
        entity = self.db.get(WorkflowRun, run_id)
        if entity is None:
            return None
        entity.status = status
        entity.error_code = error_code
        entity.error_message = error_message
        now = datetime.utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in {"completed", "failed", "cancelled"}:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity
class NodeRunRepository:
    """Persistence helpers for ``NodeRun`` rows.

    Besides basic CRUD, this implements the queue protocol visible below:
    due ``"queued"`` rows are claimed under a row lock and given a lease,
    rows whose lease expired are put back in the queue, and rows can be
    re-queued for a retry. Every mutating method commits the session.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_id: str,
        node_type: str,
        status: str,
        scheduled_time: datetime | None = None,
        timeout_time: datetime | None = None,
        parent_node_run_id: str | None = None,
    ) -> NodeRun:
        """Insert a node run and return the refreshed entity.

        ``queued_time`` is the current (naive) UTC time, and
        ``scheduled_time`` defaults to that same instant when omitted.
        """
        now = datetime.utcnow()
        entity = NodeRun(
            tenant_id=tenant_id,
            run_id=run_id,
            parent_node_run_id=parent_node_run_id,
            node_id=node_id,
            node_type=node_type,
            status=status,
            queued_time=now,
            scheduled_time=scheduled_time or now,
            timeout_time=timeout_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(self, *, tenant_id: str, run_id: str) -> list[NodeRun]:
        """Return all node runs of ``run_id`` in ``created_time`` order."""
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .order_by(NodeRun.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def list_by_run_and_node_ids(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_ids: list[str],
    ) -> list[NodeRun]:
        """Return node runs of ``run_id`` whose ``node_id`` is in ``node_ids`` (unordered).

        An empty ``node_ids`` short-circuits to ``[]`` without querying.
        """
        if not node_ids:
            return []
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.node_id.in_(node_ids))
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, node_run_id: str) -> NodeRun | None:
        """Return the node run with primary key ``node_run_id``, or None."""
        return self.db.get(NodeRun, node_run_id)

    def get_next_queued_by_run(self, *, tenant_id: str, run_id: str) -> NodeRun | None:
        """Return the oldest due ``"queued"`` node run of ``run_id``, or None.

        A row is due when it has no ``scheduled_time`` or it is not in the
        future. The row is only read, not locked or claimed.
        """
        stmt = (
            select(NodeRun)
            .where(NodeRun.tenant_id == tenant_id)
            .where(NodeRun.run_id == run_id)
            .where(NodeRun.status == "queued")
            .where(
                or_(
                    NodeRun.scheduled_time.is_(None),
                    NodeRun.scheduled_time <= datetime.utcnow(),
                )
            )
            .order_by(NodeRun.created_time.asc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def claim_next_queued(
        self,
        *,
        worker_key: str,
        lease_expire_time: datetime,
    ) -> NodeRun | None:
        """Claim the next due queued node run across all tenants and runs.

        Candidates are ordered by the owning workflow run's ``priority``
        (highest first), then by ``created_time``. The chosen row is switched
        to ``"running"``, tagged with ``worker_key`` and given
        ``lease_expire_time``. ``started_time`` is only set if it was empty.
        Returns None when nothing is due.
        """
        stmt = (
            select(NodeRun)
            .join(WorkflowRun, NodeRun.run_id == WorkflowRun.id)
            .where(NodeRun.status == "queued")
            .where(
                or_(
                    NodeRun.scheduled_time.is_(None),
                    NodeRun.scheduled_time <= datetime.utcnow(),
                )
            )
            .order_by(WorkflowRun.priority.desc(), NodeRun.created_time.asc())
            # Lock only the node_run row. Without ``of=``, PostgreSQL also locks
            # the joined workflow_run row, so concurrent workers would SKIP
            # LOCKED every other node of the same run (and updates to that
            # workflow run would block) while this claim is in flight.
            .with_for_update(skip_locked=True, of=NodeRun)
            .limit(1)
        )
        entity = self.db.scalar(stmt)
        if entity is None:
            return None
        now = datetime.utcnow()
        entity.status = "running"
        entity.worker_key = worker_key
        entity.started_time = entity.started_time or now
        entity.lease_expire_time = lease_expire_time
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
        """Put up to ``max_items`` running rows with an expired lease back in the queue.

        Rows are processed oldest lease first. Each released row becomes
        ``"queued"`` and is rescheduled and re-queued at ``now_time``. Its
        worker, lease, start and finish fields are cleared, and
        ``attempt_no`` is incremented. Returns the number of rows released;
        commits only if there was at least one.
        """
        stmt = (
            select(NodeRun)
            .where(NodeRun.status == "running")
            .where(NodeRun.lease_expire_time.is_not(None))
            .where(NodeRun.lease_expire_time <= now_time)
            .order_by(NodeRun.lease_expire_time.asc())
            .limit(max_items)
        )
        entities = list(self.db.scalars(stmt))
        for entity in entities:
            entity.status = "queued"
            entity.worker_key = None
            entity.lease_expire_time = None
            entity.scheduled_time = now_time
            entity.queued_time = now_time
            entity.started_time = None
            entity.finished_time = None
            entity.attempt_no += 1
        if entities:
            self.db.commit()
        return len(entities)

    def update_status(
        self,
        *,
        node_run_id: str,
        status: NodeRunStatus,
        worker_key: str | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
    ) -> NodeRun | None:
        """Change a node run's status and return the refreshed entity (None if missing).

        ``worker_key``, the error fields and the output fields are always
        overwritten, so omitted arguments clear them.

        - Entering ``"running"`` stamps ``started_time`` only if it was unset.
        - Any other status clears the lease.
        - ``"completed"``, ``"failed"`` and ``"skipped"`` stamp ``finished_time``.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = status
        entity.worker_key = worker_key
        entity.error_code = error_code
        entity.error_message = error_message
        entity.output_text = output_text
        entity.output_json = output_json
        now = datetime.utcnow()
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status != "running":
            entity.lease_expire_time = None
        if status in {"completed", "failed", "skipped"}:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def requeue_for_retry(
        self,
        *,
        node_run_id: str,
        scheduled_time: datetime,
        timeout_time: datetime | None,
        error_code: str | None,
        error_message: str | None,
        output_text: str | None,
        output_json: dict[str, JSONValue] | None,
    ) -> NodeRun | None:
        """Return a node run to the queue for another attempt (None if missing).

        ``attempt_no`` is incremented. The worker, lease, start and finish
        fields are cleared, and ``queued_time`` is set to now. The schedule,
        timeout, error and output fields are replaced with the given values.
        """
        entity = self.db.get(NodeRun, node_run_id)
        if entity is None:
            return None
        entity.status = "queued"
        entity.attempt_no += 1
        entity.worker_key = None
        entity.lease_expire_time = None
        entity.scheduled_time = scheduled_time
        entity.timeout_time = timeout_time
        entity.queued_time = datetime.utcnow()
        entity.started_time = None
        entity.finished_time = None
        entity.error_code = error_code
        entity.error_message = error_message
        entity.output_text = output_text
        entity.output_json = output_json
        self.db.commit()
        self.db.refresh(entity)
        return entity
class ExecutionLogRepository:
    """Append-only access to ``ExecutionLog`` rows; ``create`` commits immediately."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_run_id: str | None,
        event_type: str,
        level: str,
        message: str,
        detail_json: dict[str, JSONValue] | None,
    ) -> ExecutionLog:
        """Persist one log entry and return it refreshed from the database."""
        log_entry = ExecutionLog(
            tenant_id=tenant_id,
            run_id=run_id,
            node_run_id=node_run_id,
            event_type=event_type,
            level=level,
            message=message,
            detail_json=detail_json,
        )
        session = self.db
        session.add(log_entry)
        session.commit()
        session.refresh(log_entry)
        return log_entry

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        run_id: str | None = None,
        node_run_id: str | None = None,
    ) -> list[ExecutionLog]:
        """Return the tenant's log entries oldest first.

        ``run_id`` and ``node_run_id`` narrow the result only when not None.
        """
        criteria = [ExecutionLog.tenant_id == tenant_id]
        if run_id is not None:
            criteria.append(ExecutionLog.run_id == run_id)
        if node_run_id is not None:
            criteria.append(ExecutionLog.node_run_id == node_run_id)
        query = (
            select(ExecutionLog)
            .where(*criteria)
            .order_by(ExecutionLog.created_time.asc())
        )
        return list(self.db.scalars(query))
class NodeArtifactRepository:
    """Access to ``NodeArtifact`` rows; ``create`` commits immediately."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_run_id: str,
        node_id: str,
        artifact_type: str,
        name: str,
        mime_type: str | None,
        content_text: str | None,
        content_json: dict[str, JSONValue] | None,
        storage_uri: str | None = None,
        size_bytes: int | None = None,
    ) -> NodeArtifact:
        """Persist an artifact produced by a node run and return it refreshed.

        Content may be inline (``content_text``/``content_json``) and/or
        referenced via ``storage_uri``; no validation is done here.
        """
        artifact = NodeArtifact(
            tenant_id=tenant_id,
            run_id=run_id,
            node_run_id=node_run_id,
            node_id=node_id,
            artifact_type=artifact_type,
            name=name,
            mime_type=mime_type,
            content_text=content_text,
            content_json=content_json,
            storage_uri=storage_uri,
            size_bytes=size_bytes,
        )
        session = self.db
        session.add(artifact)
        session.commit()
        session.refresh(artifact)
        return artifact

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        run_id: str | None = None,
        node_run_id: str | None = None,
        artifact_type: str | None = None,
    ) -> list[NodeArtifact]:
        """Return the tenant's artifacts oldest first.

        Each optional argument adds an equality filter only when not None.
        """
        optional_filters = (
            (NodeArtifact.run_id, run_id),
            (NodeArtifact.node_run_id, node_run_id),
            (NodeArtifact.artifact_type, artifact_type),
        )
        criteria = [NodeArtifact.tenant_id == tenant_id]
        criteria.extend(
            column == value
            for column, value in optional_filters
            if value is not None
        )
        query = (
            select(NodeArtifact)
            .where(*criteria)
            .order_by(NodeArtifact.created_time.asc())
        )
        return list(self.db.scalars(query))
class TraceSpanRepository:
    """Open, close and query ``TraceSpan`` rows; mutating calls commit immediately."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def start(
        self,
        *,
        tenant_id: str,
        run_id: str,
        node_run_id: str | None,
        parent_span_id: str | None,
        span_type: str,
        name: str,
        attributes_json: dict[str, JSONValue] | None = None,
    ) -> TraceSpan:
        """Open a span in the ``"running"`` state, stamped with the current UTC time."""
        opened_at = datetime.utcnow()
        span = TraceSpan(
            tenant_id=tenant_id,
            run_id=run_id,
            node_run_id=node_run_id,
            parent_span_id=parent_span_id,
            span_type=span_type,
            name=name,
            status="running",
            started_time=opened_at,
            attributes_json=attributes_json,
        )
        session = self.db
        session.add(span)
        session.commit()
        session.refresh(span)
        return span

    def finish(
        self,
        *,
        span_id: str,
        status: str,
        error_code: str | None = None,
        error_message: str | None = None,
        attributes_json: dict[str, JSONValue] | None = None,
    ) -> TraceSpan | None:
        """Close a span and return it refreshed, or None if it does not exist.

        Sets ``ended_time`` to now and ``duration_ms`` to the truncated
        milliseconds since ``started_time``. Error fields are always
        overwritten. When given, ``attributes_json`` is merged over the
        existing attributes (new keys win) into a fresh dict.
        """
        span = self.db.get(TraceSpan, span_id)
        if span is None:
            return None
        closed_at = datetime.utcnow()
        span.status = status
        span.ended_time = closed_at
        elapsed = closed_at - span.started_time
        span.duration_ms = int(elapsed.total_seconds() * 1000)
        span.error_code = error_code
        span.error_message = error_message
        if attributes_json is not None:
            # Build a new dict rather than mutating in place so the ORM
            # sees the attribute as changed.
            merged = dict(span.attributes_json or {})
            merged.update(attributes_json)
            span.attributes_json = merged
        self.db.commit()
        self.db.refresh(span)
        return span

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        run_id: str | None = None,
        node_run_id: str | None = None,
        span_type: str | None = None,
    ) -> list[TraceSpan]:
        """Return the tenant's spans ordered by ``started_time`` ascending.

        Each optional argument adds an equality filter only when not None.
        """
        criteria = [TraceSpan.tenant_id == tenant_id]
        if run_id is not None:
            criteria.append(TraceSpan.run_id == run_id)
        if node_run_id is not None:
            criteria.append(TraceSpan.node_run_id == node_run_id)
        if span_type is not None:
            criteria.append(TraceSpan.span_type == span_type)
        query = (
            select(TraceSpan)
            .where(*criteria)
            .order_by(TraceSpan.started_time.asc())
        )
        return list(self.db.scalars(query))
|