from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import (
    AgentRunStatus,
    AgentStatus,
    AgentToolInvocationStatus,
    AgentVersionStatus,
)
from core_shared import JSONValue

from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentVersion


# Run statuses that mark a run as finished (stamps finished_time, drops the lease).
_RUN_TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled"})
# Tool-invocation statuses that mark an invocation as finished.
_INVOCATION_TERMINAL_STATUSES = frozenset({"completed", "failed", "skipped"})


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same value as the deprecated ``datetime.utcnow()`` (naive,
    UTC wall-clock), so stored timestamps keep their existing representation.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AgentDefinitionRepository:
    """Persistence operations for ``AgentDefinition`` rows.

    Every mutating method commits the session it was given.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        code: str,
        name: str,
        description: str | None,
        agent_type: str,
        owner_user_id: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> AgentDefinition:
        """Insert a new agent definition, commit, and return the refreshed row."""
        entity = AgentDefinition(
            code=code,
            name=name,
            description=description,
            agent_type=agent_type,
            owner_user_id=owner_user_id,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_all(self) -> list[AgentDefinition]:
        """Return every agent definition, most recently created first."""
        stmt = select(AgentDefinition).order_by(AgentDefinition.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, agent_id: str) -> AgentDefinition | None:
        """Return the agent definition with ``agent_id``, or ``None`` if absent."""
        stmt = select(AgentDefinition).where(AgentDefinition.id == agent_id)
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        agent_id: str,
        status: AgentStatus,
    ) -> AgentDefinition | None:
        """Set the status of an agent definition and commit.

        Returns:
            The refreshed row, or ``None`` when no definition has ``agent_id``.
        """
        entity = self.get_by_id(agent_id=agent_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity


class AgentVersionRepository:
    """Persistence operations for ``AgentVersion`` rows (numbered per agent)."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        agent_id: str,
        status: AgentVersionStatus,
        role: str,
        goal: str | None,
        system_prompt: str,
        model_config_json: dict[str, JSONValue],
        memory_policy_json: dict[str, JSONValue],
        tool_refs_json: list[dict[str, JSONValue]],
        skill_refs_json: list[dict[str, JSONValue]],
    ) -> AgentVersion:
        """Insert the next version of ``agent_id``, commit, and return it.

        The version number is ``max(existing version_no) + 1`` (1 for the
        first version). ``published_time`` is stamped only when ``status`` is
        ``"published"``.
        """
        # NOTE(review): the max+1 read takes no lock, so two concurrent creates
        # for the same agent can compute the same version_no. Confirm a unique
        # constraint on (agent_id, version_no) exists to reject the loser.
        version_no = self._next_version_no(agent_id)
        entity = AgentVersion(
            agent_id=agent_id,
            version_no=version_no,
            status=status,
            role=role,
            goal=goal,
            system_prompt=system_prompt,
            model_config_json=model_config_json,
            memory_policy_json=memory_policy_json,
            tool_refs_json=tool_refs_json,
            skill_refs_json=skill_refs_json,
            published_time=_utcnow() if status == "published" else None,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_agent(self, *, agent_id: str) -> list[AgentVersion]:
        """Return all versions of ``agent_id``, highest ``version_no`` first."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.agent_id == agent_id)
            .order_by(AgentVersion.version_no.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, agent_version_id: str) -> AgentVersion | None:
        """Return the version with ``agent_version_id``, or ``None`` if absent."""
        stmt = select(AgentVersion).where(AgentVersion.id == agent_version_id)
        return self.db.scalar(stmt)

    def get_latest_published(self, *, agent_id: str) -> AgentVersion | None:
        """Return the highest-numbered ``"published"`` version, or ``None``."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.agent_id == agent_id)
            .where(AgentVersion.status == "published")
            .order_by(AgentVersion.version_no.desc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def _next_version_no(self, agent_id: str) -> int:
        """Return one more than the agent's current max ``version_no`` (or 1)."""
        stmt = select(func.max(AgentVersion.version_no)).where(
            AgentVersion.agent_id == agent_id
        )
        current_max = self.db.scalar(stmt)
        return (current_max or 0) + 1


class AgentRunRepository:
    """Persistence and lease-based work-queue operations for ``AgentRun`` rows.

    Lifecycle visible here:
    - ``create`` inserts runs as ``"queued"``.
    - ``claim_next_queued`` moves a run to ``"running"`` under a worker lease.
    - ``release_expired_leases`` returns runs whose lease has lapsed to
      ``"queued"``.
    - ``update_status`` records progress and terminal outcomes.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        agent_id: str,
        agent_version_id: str,
        session_id: str | None,
        input_text: str | None,
        input_json: dict[str, JSONValue] | None,
    ) -> AgentRun:
        """Insert a new run in ``"queued"`` state, commit, and return it."""
        entity = AgentRun(
            agent_id=agent_id,
            agent_version_id=agent_version_id,
            session_id=session_id,
            input_text=input_text,
            input_json=input_json,
            status="queued",
            queued_time=_utcnow(),
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        agent_id: str | None = None,
        session_id: str | None = None,
    ) -> list[AgentRun]:
        """Return runs, newest first, optionally filtered by agent and/or session.

        A filter left as ``None`` is not applied.
        """
        stmt = select(AgentRun)
        if agent_id is not None:
            stmt = stmt.where(AgentRun.agent_id == agent_id)
        if session_id is not None:
            stmt = stmt.where(AgentRun.session_id == session_id)
        stmt = stmt.order_by(AgentRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, agent_run_id: str) -> AgentRun | None:
        """Return the run with ``agent_run_id``, or ``None`` if absent."""
        stmt = select(AgentRun).where(AgentRun.id == agent_run_id)
        return self.db.scalar(stmt)

    def claim_next_queued(
        self,
        *,
        worker_key: str,
        lease_expire_time: datetime,
    ) -> AgentRun | None:
        """Claim the oldest queued run for ``worker_key`` and commit.

        The candidate row is selected ``FOR UPDATE SKIP LOCKED`` (on backends
        that support it), so concurrent claimers skip rows another transaction
        holds instead of blocking on them. The run becomes ``"running"``,
        receives ``worker_key`` and ``lease_expire_time``, and keeps an existing
        ``started_time`` or gets the current time.

        Returns:
            The claimed, refreshed run, or ``None`` if no queued run is available.
        """
        stmt = (
            select(AgentRun)
            .where(AgentRun.status == "queued")
            .order_by(AgentRun.created_time.asc())
            .with_for_update(skip_locked=True)
            .limit(1)
        )
        entity = self.db.scalar(stmt)
        if entity is None:
            return None
        now = _utcnow()
        entity.status = "running"
        entity.worker_key = worker_key
        entity.started_time = entity.started_time or now
        entity.lease_expire_time = lease_expire_time
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
        """Requeue up to ``max_items`` running runs whose lease expired by ``now_time``.

        Each released run returns to ``"queued"`` with its worker, lease, and
        start/finish timestamps cleared, and ``queued_time`` set to
        ``now_time``. Runs are processed oldest lease first.

        Returns:
            The number of runs released.
        """
        # NOTE(review): now_time is compared to stored lease_expire_time values,
        # presumably naive UTC like _utcnow(); confirm callers pass that form.
        stmt = (
            select(AgentRun)
            .where(AgentRun.status == "running")
            .where(AgentRun.lease_expire_time.is_not(None))
            .where(AgentRun.lease_expire_time <= now_time)
            .order_by(AgentRun.lease_expire_time.asc())
            # Lock the rows being requeued, as claim_next_queued does. Without a
            # lock, a worker could commit "completed" between this SELECT and
            # our COMMIT, and we would overwrite it back to "queued". Rows
            # another transaction is currently writing are skipped.
            .with_for_update(skip_locked=True)
            .limit(max_items)
        )
        entities = list(self.db.scalars(stmt))
        for entity in entities:
            entity.status = "queued"
            entity.worker_key = None
            entity.lease_expire_time = None
            entity.queued_time = now_time
            entity.started_time = None
            entity.finished_time = None
        if entities:
            self.db.commit()
        return len(entities)

    def update_status(
        self,
        *,
        agent_run_id: str,
        status: AgentRunStatus,
        worker_key: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> AgentRun | None:
        """Overwrite a run's status and result fields, then commit.

        Every optional field is written unconditionally, so an omitted argument
        clears any value previously stored on the run.

        Timestamps:
        - ``"running"`` sets ``started_time`` if it is still unset.
        - A terminal status (completed/failed/cancelled) stamps
          ``finished_time`` and drops the lease.

        Returns:
            The refreshed run, or ``None`` when no run has ``agent_run_id``.
        """
        entity = self.db.get(AgentRun, agent_run_id)
        if entity is None:
            return None
        now = _utcnow()
        entity.status = status
        entity.worker_key = worker_key
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_code = error_code
        entity.error_message = error_message
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in _RUN_TERMINAL_STATUSES:
            entity.finished_time = now
            entity.lease_expire_time = None
        self.db.commit()
        self.db.refresh(entity)
        return entity


class AgentToolInvocationRepository:
    """Persistence operations for ``AgentToolInvocation`` rows within a run."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        agent_run_id: str,
        agent_id: str,
        agent_version_id: str,
        tool_code: str | None,
        tool_binding_id: str | None,
        status: AgentToolInvocationStatus,
        reason: str | None = None,
        input_json: dict[str, JSONValue] | None = None,
    ) -> AgentToolInvocation:
        """Insert a tool invocation record, commit, and return it.

        A missing or empty ``input_json`` is stored as ``{}``.
        """
        entity = AgentToolInvocation(
            agent_run_id=agent_run_id,
            agent_id=agent_id,
            agent_version_id=agent_version_id,
            tool_code=tool_code,
            tool_binding_id=tool_binding_id,
            status=status,
            reason=reason,
            input_json=input_json or {},
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(
        self,
        *,
        agent_run_id: str,
    ) -> list[AgentToolInvocation]:
        """Return a run's tool invocations in creation order (oldest first)."""
        stmt = (
            select(AgentToolInvocation)
            .where(AgentToolInvocation.agent_run_id == agent_run_id)
            .order_by(AgentToolInvocation.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def update_status(
        self,
        *,
        invocation_id: str,
        status: AgentToolInvocationStatus,
        reason: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_message: str | None = None,
    ) -> AgentToolInvocation | None:
        """Overwrite an invocation's status and result fields, then commit.

        Every optional field is written unconditionally, so an omitted argument
        clears any previously stored value.

        Timestamps:
        - ``"running"`` sets ``started_time`` if it is still unset.
        - A terminal status (completed/failed/skipped) stamps ``finished_time``.

        Returns:
            The refreshed invocation, or ``None`` when ``invocation_id`` is unknown.
        """
        entity = self.db.get(AgentToolInvocation, invocation_id)
        if entity is None:
            return None
        now = _utcnow()
        entity.status = status
        entity.reason = reason
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_message = error_message
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in _INVOCATION_TERMINAL_STATUSES:
            # Backfill started_time so a finished invocation always carries both
            # timestamps, even if it never passed through "running".
            if entity.started_time is None:
                entity.started_time = now
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity