from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import AgentRunStatus, AgentStatus, AgentVersionStatus
from core_shared import JSONValue

from app.db.models import AgentDefinition, AgentRun, AgentVersion

# Run statuses that end a run; moving a run into any of them stamps
# ``AgentRun.finished_time``.
_TERMINAL_RUN_STATUSES = frozenset({"completed", "failed", "cancelled"})


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The value is computed timezone-aware and the tzinfo is then
    stripped, so callers get exactly the same naive-UTC value as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AgentDefinitionRepository:
    """Data access for ``AgentDefinition`` rows.

    All lookups filter on ``tenant_id``. Mutating methods commit the session
    and refresh the entity before returning it.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        code: str,
        name: str,
        description: str | None,
        agent_type: str,
        owner_user_id: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> AgentDefinition:
        """Insert a new agent definition, commit, and return the refreshed row."""
        entity = AgentDefinition(
            tenant_id=tenant_id,
            code=code,
            name=name,
            description=description,
            agent_type=agent_type,
            owner_user_id=owner_user_id,
            metadata_json=metadata_json,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_tenant(self, *, tenant_id: str) -> list[AgentDefinition]:
        """Return all agent definitions of a tenant, newest ``created_time`` first."""
        stmt = (
            select(AgentDefinition)
            .where(AgentDefinition.tenant_id == tenant_id)
            .order_by(AgentDefinition.created_time.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, agent_id: str) -> AgentDefinition | None:
        """Return the agent definition with ``agent_id`` in the tenant, or ``None``."""
        stmt = (
            select(AgentDefinition)
            .where(AgentDefinition.tenant_id == tenant_id)
            .where(AgentDefinition.id == agent_id)
        )
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        tenant_id: str,
        agent_id: str,
        status: AgentStatus,
    ) -> AgentDefinition | None:
        """Set the status of an agent definition and commit.

        Returns the refreshed entity, or ``None`` if no definition with
        ``agent_id`` exists in the tenant.
        """
        entity = self.get_by_id(tenant_id=tenant_id, agent_id=agent_id)
        if entity is None:
            return None
        entity.status = status
        self.db.commit()
        self.db.refresh(entity)
        return entity


class AgentVersionRepository:
    """Data access for ``AgentVersion`` rows (numbered versions of an agent)."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        agent_id: str,
        status: AgentVersionStatus,
        role: str,
        goal: str | None,
        system_prompt: str,
        model_config_json: dict[str, JSONValue],
        memory_policy_json: dict[str, JSONValue],
        tool_refs_json: list[dict[str, JSONValue]],
        skill_refs_json: list[dict[str, JSONValue]],
    ) -> AgentVersion:
        """Insert the next version of an agent, commit, and return it.

        ``version_no`` is one greater than the agent's current highest version
        (1 for the first). ``published_time`` is stamped with the current
        naive-UTC time only when the version is created as ``"published"``.
        """
        version_no = self._next_version_no(agent_id)
        entity = AgentVersion(
            tenant_id=tenant_id,
            agent_id=agent_id,
            version_no=version_no,
            status=status,
            role=role,
            goal=goal,
            system_prompt=system_prompt,
            model_config_json=model_config_json,
            memory_policy_json=memory_policy_json,
            tool_refs_json=tool_refs_json,
            skill_refs_json=skill_refs_json,
            published_time=_utcnow() if status == "published" else None,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_agent(self, *, tenant_id: str, agent_id: str) -> list[AgentVersion]:
        """Return all versions of an agent in the tenant, highest ``version_no`` first."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.tenant_id == tenant_id)
            .where(AgentVersion.agent_id == agent_id)
            .order_by(AgentVersion.version_no.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, agent_version_id: str) -> AgentVersion | None:
        """Return the version with ``agent_version_id`` in the tenant, or ``None``."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.tenant_id == tenant_id)
            .where(AgentVersion.id == agent_version_id)
        )
        return self.db.scalar(stmt)

    def get_latest_published(self, *, tenant_id: str, agent_id: str) -> AgentVersion | None:
        """Return the highest-numbered ``"published"`` version of an agent, or ``None``."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.tenant_id == tenant_id)
            .where(AgentVersion.agent_id == agent_id)
            .where(AgentVersion.status == "published")
            .order_by(AgentVersion.version_no.desc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def _next_version_no(self, agent_id: str) -> int:
        """Return ``max(version_no) + 1`` for the agent, or 1 if it has no versions."""
        # NOTE(review): this read-then-insert is not atomic; two concurrent
        # creates for the same agent can compute the same number unless a
        # unique constraint on (agent_id, version_no) exists — confirm.
        # Unlike the other queries here it is not filtered by tenant_id;
        # presumably agent ids are globally unique — verify.
        stmt = select(func.max(AgentVersion.version_no)).where(AgentVersion.agent_id == agent_id)
        current_max = self.db.scalar(stmt)
        return (current_max or 0) + 1


class AgentRunRepository:
    """Data access for ``AgentRun`` rows (individual executions of an agent version)."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def create(
        self,
        *,
        tenant_id: str,
        agent_id: str,
        agent_version_id: str,
        session_id: str | None,
        input_text: str | None,
        input_json: dict[str, JSONValue] | None,
    ) -> AgentRun:
        """Insert a new run in ``"queued"`` status, commit, and return it."""
        entity = AgentRun(
            tenant_id=tenant_id,
            agent_id=agent_id,
            agent_version_id=agent_version_id,
            session_id=session_id,
            input_text=input_text,
            input_json=input_json,
            status="queued",
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        tenant_id: str,
        agent_id: str | None = None,
        session_id: str | None = None,
    ) -> list[AgentRun]:
        """Return a tenant's runs, newest ``created_time`` first.

        ``agent_id`` and ``session_id`` each narrow the result when given;
        ``None`` means "do not filter on this column".
        """
        stmt = select(AgentRun).where(AgentRun.tenant_id == tenant_id)
        if agent_id is not None:
            stmt = stmt.where(AgentRun.agent_id == agent_id)
        if session_id is not None:
            stmt = stmt.where(AgentRun.session_id == session_id)
        stmt = stmt.order_by(AgentRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, tenant_id: str, agent_run_id: str) -> AgentRun | None:
        """Return the run with ``agent_run_id`` in the tenant, or ``None``."""
        stmt = (
            select(AgentRun)
            .where(AgentRun.tenant_id == tenant_id)
            .where(AgentRun.id == agent_run_id)
        )
        return self.db.scalar(stmt)

    def update_status(
        self,
        *,
        agent_run_id: str,
        status: AgentRunStatus,
        worker_key: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> AgentRun | None:
        """Move a run to ``status``, record its outputs/errors, and commit.

        ``started_time`` is set the first time the run enters ``"running"``.
        ``finished_time`` is set whenever the run enters a terminal status
        (``completed``, ``failed``, ``cancelled``).

        Returns the refreshed run, or ``None`` if no run has ``agent_run_id``.
        """
        # Looked up by primary key only — this method is not tenant-scoped,
        # unlike get_by_id.
        entity = self.db.get(AgentRun, agent_run_id)
        if entity is None:
            return None

        now = _utcnow()
        entity.status = status
        # NOTE(review): these fields are assigned unconditionally, so omitting
        # e.g. ``worker_key`` on a later transition clears a previously stored
        # value. Confirm callers always pass the fields they want to keep.
        entity.worker_key = worker_key
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_code = error_code
        entity.error_message = error_message
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in _TERMINAL_RUN_STATUSES:
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity