| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import (
    AgentRunStatus,
    AgentStatus,
    AgentToolInvocationStatus,
    AgentVersionStatus,
)
from core_shared import JSONValue

from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentVersion
class AgentDefinitionRepository:
    """Data-access helper for ``AgentDefinition`` rows.

    Mutating methods commit the session supplied at construction time and
    refresh the affected entity before returning it.
    """

    def __init__(self, db: Session) -> None:
        # Session used for every query and commit issued by this repository.
        self.db = db

    def _commit_and_refresh(self, definition: AgentDefinition) -> AgentDefinition:
        """Commit pending changes and reload ``definition`` from the database."""
        session = self.db
        session.commit()
        session.refresh(definition)
        return definition

    def create(
        self,
        *,
        code: str,
        name: str,
        description: str | None,
        agent_type: str,
        owner_user_id: str | None,
        metadata_json: dict[str, JSONValue] | None,
    ) -> AgentDefinition:
        """Insert a new agent definition, commit, and return the refreshed row."""
        definition = AgentDefinition(
            code=code,
            name=name,
            description=description,
            agent_type=agent_type,
            owner_user_id=owner_user_id,
            metadata_json=metadata_json,
        )
        self.db.add(definition)
        return self._commit_and_refresh(definition)

    def list_all(self) -> list[AgentDefinition]:
        """Return every agent definition, most recently created first."""
        newest_first = AgentDefinition.created_time.desc()
        query = select(AgentDefinition).order_by(newest_first)
        rows = self.db.scalars(query)
        return list(rows)

    def get_by_id(self, *, agent_id: str) -> AgentDefinition | None:
        """Fetch the agent definition whose ``id`` equals ``agent_id``, or ``None``."""
        query = select(AgentDefinition).where(AgentDefinition.id == agent_id)
        return self.db.scalar(query)

    def update_status(
        self,
        *,
        agent_id: str,
        status: AgentStatus,
    ) -> AgentDefinition | None:
        """Set ``status`` on an existing agent definition and commit.

        Returns the refreshed entity, or ``None`` when no definition has
        ``agent_id`` (nothing is committed in that case).
        """
        definition = self.get_by_id(agent_id=agent_id)
        if definition is not None:
            definition.status = status
            self._commit_and_refresh(definition)
        return definition
class AgentVersionRepository:
    """Persistence operations for ``AgentVersion`` rows.

    Versions are numbered per agent, starting at 1. Mutating methods commit
    the injected session and refresh the entity before returning it.
    """

    def __init__(self, db: Session) -> None:
        # Session used for every query and commit issued by this repository.
        self.db = db

    def create(
        self,
        *,
        agent_id: str,
        status: AgentVersionStatus,
        role: str,
        goal: str | None,
        system_prompt: str,
        model_config_json: dict[str, JSONValue],
        memory_policy_json: dict[str, JSONValue],
        tool_refs_json: list[dict[str, JSONValue]],
        skill_refs_json: list[dict[str, JSONValue]],
    ) -> AgentVersion:
        """Insert the next version of ``agent_id`` and commit.

        ``version_no`` is one greater than the agent's current highest version
        number (1 for the first version). When the version is created directly
        in the ``"published"`` status, ``published_time`` is stamped with the
        current UTC time. Otherwise it is left ``None``.

        Returns the refreshed entity.
        """
        version_no = self._next_version_no(agent_id)
        # Naive UTC timestamp, identical in value to the deprecated
        # ``datetime.utcnow()`` so stored/compared values do not change.
        published_time = (
            datetime.now(timezone.utc).replace(tzinfo=None)
            if status == "published"
            else None
        )
        entity = AgentVersion(
            agent_id=agent_id,
            version_no=version_no,
            status=status,
            role=role,
            goal=goal,
            system_prompt=system_prompt,
            model_config_json=model_config_json,
            memory_policy_json=memory_policy_json,
            tool_refs_json=tool_refs_json,
            skill_refs_json=skill_refs_json,
            published_time=published_time,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_agent(self, *, agent_id: str) -> list[AgentVersion]:
        """Return all versions of ``agent_id``, highest ``version_no`` first."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.agent_id == agent_id)
            .order_by(AgentVersion.version_no.desc())
        )
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, agent_version_id: str) -> AgentVersion | None:
        """Fetch a version by its ``id``, or ``None`` if it does not exist."""
        stmt = select(AgentVersion).where(AgentVersion.id == agent_version_id)
        return self.db.scalar(stmt)

    def get_latest_published(self, *, agent_id: str) -> AgentVersion | None:
        """Return the ``"published"`` version of ``agent_id`` with the highest
        ``version_no``, or ``None`` if the agent has no published version."""
        stmt = (
            select(AgentVersion)
            .where(AgentVersion.agent_id == agent_id)
            .where(AgentVersion.status == "published")
            .order_by(AgentVersion.version_no.desc())
            .limit(1)
        )
        return self.db.scalar(stmt)

    def _next_version_no(self, agent_id: str) -> int:
        """Return ``max(version_no) + 1`` for ``agent_id`` (1 when it has none).

        NOTE(review): this is an unlocked read that happens before the insert,
        so two concurrent ``create`` calls for the same agent can compute the same
        number. Presumably a unique constraint on (agent_id, version_no) rejects
        the loser; confirm in the model.
        """
        stmt = select(func.max(AgentVersion.version_no)).where(
            AgentVersion.agent_id == agent_id
        )
        current_max = self.db.scalar(stmt)
        return (current_max or 0) + 1
class AgentRunRepository:
    """Persistence and queue operations for ``AgentRun`` rows.

    Runs start in ``"queued"``. ``claim_next_queued`` moves a run to
    ``"running"`` under a lease held by a worker, and
    ``release_expired_leases`` puts runs whose lease lapsed back in the queue.
    All timestamps written here are naive UTC.
    """

    def __init__(self, db: Session) -> None:
        # Session used for every query and commit issued by this repository.
        self.db = db

    def create(
        self,
        *,
        agent_id: str,
        agent_version_id: str,
        session_id: str | None,
        input_text: str | None,
        input_json: dict[str, JSONValue] | None,
    ) -> AgentRun:
        """Insert a new run in the ``"queued"`` status and commit.

        ``queued_time`` is set to the current UTC time. Returns the refreshed
        entity.
        """
        # Naive UTC, same value as the deprecated ``datetime.utcnow()``.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity = AgentRun(
            agent_id=agent_id,
            agent_version_id=agent_version_id,
            session_id=session_id,
            input_text=input_text,
            input_json=input_json,
            status="queued",
            queued_time=now,
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_scope(
        self,
        *,
        agent_id: str | None = None,
        session_id: str | None = None,
    ) -> list[AgentRun]:
        """Return runs filtered by whichever of ``agent_id``/``session_id`` is
        given (no filter for ``None``), most recently created first."""
        stmt = select(AgentRun)
        if agent_id is not None:
            stmt = stmt.where(AgentRun.agent_id == agent_id)
        if session_id is not None:
            stmt = stmt.where(AgentRun.session_id == session_id)
        stmt = stmt.order_by(AgentRun.created_time.desc())
        return list(self.db.scalars(stmt))

    def get_by_id(self, *, agent_run_id: str) -> AgentRun | None:
        """Fetch a run by its ``id``, or ``None`` if it does not exist."""
        stmt = select(AgentRun).where(AgentRun.id == agent_run_id)
        return self.db.scalar(stmt)

    def claim_next_queued(
        self,
        *,
        worker_key: str,
        lease_expire_time: datetime,
    ) -> AgentRun | None:
        """Claim the oldest queued run for ``worker_key`` and commit.

        The candidate is selected with ``FOR UPDATE SKIP LOCKED``. On backends
        with row locks, a row locked by another transaction is skipped rather
        than waited on, so concurrent workers do not claim the same run.
        Candidates are ordered by ``created_time``, not ``queued_time``, so a
        re-queued run keeps its original position.

        The claimed run becomes ``"running"`` with the given worker and lease.
        ``started_time`` is only set if it was not already recorded.

        Returns the refreshed run, or ``None`` when no claimable run exists.
        """
        stmt = (
            select(AgentRun)
            .where(AgentRun.status == "queued")
            .order_by(AgentRun.created_time.asc())
            .with_for_update(skip_locked=True)
            .limit(1)
        )
        entity = self.db.scalar(stmt)
        if entity is None:
            return None
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity.status = "running"
        entity.worker_key = worker_key
        entity.started_time = entity.started_time or now
        entity.lease_expire_time = lease_expire_time
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
        """Re-queue up to ``max_items`` running runs whose lease has expired.

        A run qualifies when it is ``"running"`` and its ``lease_expire_time``
        is set and not after ``now_time``. The earliest-expired runs are
        processed first. Each one is reset to ``"queued"`` with
        ``queued_time = now_time``, and its worker, lease, start and finish
        fields are cleared.

        Candidate rows are locked with ``FOR UPDATE SKIP LOCKED``, matching
        ``claim_next_queued``. Without the lock, a worker could commit a
        terminal status between this read and the commit below, and the reset
        would overwrite it. Two reapers could also process the same row.

        Returns the number of runs re-queued. Commits only if there was at
        least one.
        """
        stmt = (
            select(AgentRun)
            .where(AgentRun.status == "running")
            .where(AgentRun.lease_expire_time.is_not(None))
            .where(AgentRun.lease_expire_time <= now_time)
            .order_by(AgentRun.lease_expire_time.asc())
            .with_for_update(skip_locked=True)
            .limit(max_items)
        )
        entities = list(self.db.scalars(stmt))
        for entity in entities:
            entity.status = "queued"
            entity.worker_key = None
            entity.lease_expire_time = None
            entity.queued_time = now_time
            entity.started_time = None
            entity.finished_time = None
        if entities:
            self.db.commit()
        return len(entities)

    def update_status(
        self,
        *,
        agent_run_id: str,
        status: AgentRunStatus,
        worker_key: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_code: str | None = None,
        error_message: str | None = None,
    ) -> AgentRun | None:
        """Overwrite a run's status and result fields, then commit.

        ``worker_key``, ``output_text``, ``output_json``, ``error_code`` and
        ``error_message`` are always assigned. Omitting one clears any value
        stored earlier.

        Timestamps follow the status:
        - ``"running"`` sets ``started_time`` if it is still unset.
        - ``"completed"``, ``"failed"`` and ``"cancelled"`` set
          ``finished_time`` and drop the lease.

        Returns the refreshed run, or ``None`` if no run has ``agent_run_id``.
        """
        entity = self.db.get(AgentRun, agent_run_id)
        if entity is None:
            return None
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity.status = status
        entity.worker_key = worker_key
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_code = error_code
        entity.error_message = error_message
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in {"completed", "failed", "cancelled"}:
            entity.finished_time = now
            entity.lease_expire_time = None
        self.db.commit()
        self.db.refresh(entity)
        return entity
class AgentToolInvocationRepository:
    """Persistence operations for ``AgentToolInvocation`` rows.

    Rows reference their owning run through ``agent_run_id``. Mutating
    methods commit the injected session and refresh the entity before
    returning it. Timestamps written here are naive UTC.
    """

    def __init__(self, db: Session) -> None:
        # Session used for every query and commit issued by this repository.
        self.db = db

    def create(
        self,
        *,
        agent_run_id: str,
        agent_id: str,
        agent_version_id: str,
        tool_code: str | None,
        tool_binding_id: str | None,
        status: AgentToolInvocationStatus,
        reason: str | None = None,
        input_json: dict[str, JSONValue] | None = None,
    ) -> AgentToolInvocation:
        """Insert a tool-invocation record and commit.

        A missing (or empty) ``input_json`` is stored as ``{}``. Returns the
        refreshed entity.
        """
        entity = AgentToolInvocation(
            agent_run_id=agent_run_id,
            agent_id=agent_id,
            agent_version_id=agent_version_id,
            tool_code=tool_code,
            tool_binding_id=tool_binding_id,
            status=status,
            reason=reason,
            input_json=input_json or {},
        )
        self.db.add(entity)
        self.db.commit()
        self.db.refresh(entity)
        return entity

    def list_by_run(
        self,
        *,
        agent_run_id: str,
    ) -> list[AgentToolInvocation]:
        """Return the invocations recorded for ``agent_run_id``, oldest first."""
        stmt = (
            select(AgentToolInvocation)
            .where(AgentToolInvocation.agent_run_id == agent_run_id)
            .order_by(AgentToolInvocation.created_time.asc())
        )
        return list(self.db.scalars(stmt))

    def update_status(
        self,
        *,
        invocation_id: str,
        status: AgentToolInvocationStatus,
        reason: str | None = None,
        output_text: str | None = None,
        output_json: dict[str, JSONValue] | None = None,
        error_message: str | None = None,
    ) -> AgentToolInvocation | None:
        """Overwrite an invocation's status and result fields, then commit.

        ``reason``, ``output_text``, ``output_json`` and ``error_message`` are
        always assigned. Omitting one clears any value stored earlier.

        ``started_time`` is stamped the first time the invocation reaches
        ``"running"`` or a terminal status (``"completed"``, ``"failed"``,
        ``"skipped"``). Terminal statuses also set ``finished_time``.

        Returns the refreshed entity, or ``None`` if no invocation has
        ``invocation_id``.
        """
        entity = self.db.get(AgentToolInvocation, invocation_id)
        if entity is None:
            return None
        # Naive UTC, same value as the deprecated ``datetime.utcnow()``.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        entity.status = status
        entity.reason = reason
        entity.output_text = output_text
        entity.output_json = output_json
        entity.error_message = error_message
        if status == "running" and entity.started_time is None:
            entity.started_time = now
        if status in {"completed", "failed", "skipped"}:
            # An invocation can finish without ever being marked running
            # (e.g. skipped), so backfill the start time.
            if entity.started_time is None:
                entity.started_time = now
            entity.finished_time = now
        self.db.commit()
        self.db.refresh(entity)
        return entity
|