repositories.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  from datetime import datetime, timezone

  from sqlalchemy import func, select
  from sqlalchemy.orm import Session

  from core_domain import AgentRunStatus, AgentStatus, AgentVersionStatus
  from core_shared import JSONValue

  from app.db.models import AgentDefinition, AgentRun, AgentVersion
  class AgentDefinitionRepository:
      """Tenant-scoped data access for ``AgentDefinition`` rows.

      Write methods commit the session immediately and then refresh the
      instance, reloading its attributes from the database before returning it.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          code: str,
          name: str,
          description: str | None,
          agent_type: str,
          owner_user_id: str | None,
          metadata_json: dict[str, JSONValue] | None,
      ) -> AgentDefinition:
          """Insert a new agent definition for ``tenant_id`` and return the stored row."""
          definition = AgentDefinition(
              tenant_id=tenant_id,
              code=code,
              name=name,
              description=description,
              agent_type=agent_type,
              owner_user_id=owner_user_id,
              metadata_json=metadata_json,
          )
          self.db.add(definition)
          return self._commit_and_reload(definition)

      def list_by_tenant(self, *, tenant_id: str) -> list[AgentDefinition]:
          """Return all of the tenant's agent definitions, most recently created first."""
          same_tenant = AgentDefinition.tenant_id == tenant_id
          newest_first = AgentDefinition.created_time.desc()
          query = select(AgentDefinition).where(same_tenant).order_by(newest_first)
          return list(self.db.scalars(query))

      def get_by_id(self, *, tenant_id: str, agent_id: str) -> AgentDefinition | None:
          """Return definition ``agent_id`` if it belongs to ``tenant_id``, else ``None``."""
          query = select(AgentDefinition).where(
              AgentDefinition.tenant_id == tenant_id,
              AgentDefinition.id == agent_id,
          )
          return self.db.scalar(query)

      def update_status(
          self,
          *,
          tenant_id: str,
          agent_id: str,
          status: AgentStatus,
      ) -> AgentDefinition | None:
          """Set the status of the tenant's definition ``agent_id``.

          Returns the refreshed row, or ``None`` when the tenant has no such
          definition (in which case nothing is committed).
          """
          definition = self.get_by_id(tenant_id=tenant_id, agent_id=agent_id)
          if definition is not None:
              definition.status = status
              definition = self._commit_and_reload(definition)
          return definition

      def _commit_and_reload(self, definition: AgentDefinition) -> AgentDefinition:
          """Commit the session, refresh ``definition`` from the database, and return it."""
          self.db.commit()
          self.db.refresh(definition)
          return definition
  class AgentVersionRepository:
      """Tenant-scoped data access for ``AgentVersion`` rows.

      Versions of one agent are numbered sequentially through ``version_no``.
      """

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          agent_id: str,
          status: AgentVersionStatus,
          role: str,
          goal: str | None,
          system_prompt: str,
          model_config_json: dict[str, JSONValue],
          memory_policy_json: dict[str, JSONValue],
          tool_refs_json: list[dict[str, JSONValue]],
          skill_refs_json: list[dict[str, JSONValue]],
      ) -> AgentVersion:
          """Insert a new version of ``agent_id``, commit, and return the refreshed row.

          ``version_no`` is one greater than the agent's current highest version
          (1 for the first version). ``published_time`` is stamped with the
          current UTC time, as a naive datetime, only when ``status`` equals
          ``"published"``; otherwise it is ``None``.
          """
          version_no = self._next_version_no(agent_id)
          # Naive UTC timestamp, equivalent to the deprecated ``datetime.utcnow()``
          # (deprecated since Python 3.12) without its DeprecationWarning.
          published_time = (
              datetime.now(timezone.utc).replace(tzinfo=None)
              if status == "published"
              else None
          )
          entity = AgentVersion(
              tenant_id=tenant_id,
              agent_id=agent_id,
              version_no=version_no,
              status=status,
              role=role,
              goal=goal,
              system_prompt=system_prompt,
              model_config_json=model_config_json,
              memory_policy_json=memory_policy_json,
              tool_refs_json=tool_refs_json,
              skill_refs_json=skill_refs_json,
              published_time=published_time,
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_agent(self, *, tenant_id: str, agent_id: str) -> list[AgentVersion]:
          """Return all versions of the tenant's agent, highest ``version_no`` first."""
          stmt = (
              select(AgentVersion)
              .where(AgentVersion.tenant_id == tenant_id)
              .where(AgentVersion.agent_id == agent_id)
              .order_by(AgentVersion.version_no.desc())
          )
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, tenant_id: str, agent_version_id: str) -> AgentVersion | None:
          """Return version ``agent_version_id`` if it belongs to ``tenant_id``, else ``None``."""
          stmt = (
              select(AgentVersion)
              .where(AgentVersion.tenant_id == tenant_id)
              .where(AgentVersion.id == agent_version_id)
          )
          return self.db.scalar(stmt)

      def get_latest_published(self, *, tenant_id: str, agent_id: str) -> AgentVersion | None:
          """Return the agent's highest-numbered ``"published"`` version, or ``None`` if none."""
          stmt = (
              select(AgentVersion)
              .where(AgentVersion.tenant_id == tenant_id)
              .where(AgentVersion.agent_id == agent_id)
              .where(AgentVersion.status == "published")
              .order_by(AgentVersion.version_no.desc())
              .limit(1)
          )
          return self.db.scalar(stmt)

      def _next_version_no(self, agent_id: str) -> int:
          """Return ``MAX(version_no) + 1`` for ``agent_id``, or 1 when it has no versions.

          The query filters by ``agent_id`` only (no tenant filter).
          NOTE(review): presumably agent ids are globally unique — verify.
          NOTE(review): this read-then-insert is not atomic. Two concurrent
          ``create`` calls for the same agent can compute the same number.
          Confirm a unique constraint on (agent_id, version_no) exists.
          """
          stmt = select(func.max(AgentVersion.version_no)).where(AgentVersion.agent_id == agent_id)
          current_max = self.db.scalar(stmt)
          return (current_max or 0) + 1
  class AgentRunRepository:
      """Data access for ``AgentRun`` rows."""

      # Statuses that end a run; entering one stamps ``finished_time``.
      _TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled"})

      def __init__(self, db: Session) -> None:
          self.db = db

      def create(
          self,
          *,
          tenant_id: str,
          agent_id: str,
          agent_version_id: str,
          session_id: str | None,
          input_text: str | None,
          input_json: dict[str, JSONValue] | None,
      ) -> AgentRun:
          """Insert a new run with status ``"queued"``, commit, and return the refreshed row."""
          entity = AgentRun(
              tenant_id=tenant_id,
              agent_id=agent_id,
              agent_version_id=agent_version_id,
              session_id=session_id,
              input_text=input_text,
              input_json=input_json,
              status="queued",
          )
          self.db.add(entity)
          self.db.commit()
          self.db.refresh(entity)
          return entity

      def list_by_scope(
          self,
          *,
          tenant_id: str,
          agent_id: str | None = None,
          session_id: str | None = None,
      ) -> list[AgentRun]:
          """Return the tenant's runs, newest first.

          ``agent_id`` and ``session_id`` narrow the result when given (not ``None``).
          """
          stmt = select(AgentRun).where(AgentRun.tenant_id == tenant_id)
          if agent_id is not None:
              stmt = stmt.where(AgentRun.agent_id == agent_id)
          if session_id is not None:
              stmt = stmt.where(AgentRun.session_id == session_id)
          stmt = stmt.order_by(AgentRun.created_time.desc())
          return list(self.db.scalars(stmt))

      def get_by_id(self, *, tenant_id: str, agent_run_id: str) -> AgentRun | None:
          """Return run ``agent_run_id`` if it belongs to ``tenant_id``, else ``None``."""
          stmt = (
              select(AgentRun)
              .where(AgentRun.tenant_id == tenant_id)
              .where(AgentRun.id == agent_run_id)
          )
          return self.db.scalar(stmt)

      def update_status(
          self,
          *,
          agent_run_id: str,
          status: AgentRunStatus,
          worker_key: str | None = None,
          output_text: str | None = None,
          output_json: dict[str, JSONValue] | None = None,
          error_code: str | None = None,
          error_message: str | None = None,
      ) -> AgentRun | None:
          """Set a run's status and result fields, stamping lifecycle timestamps, then commit.

          The run is looked up by primary key only; there is no tenant filter.

          ``worker_key``, ``output_text``, ``output_json``, ``error_code`` and
          ``error_message`` are always overwritten with the values passed, so an
          omitted argument resets that column to ``None``.
          NOTE(review): e.g. a ``"completed"`` update that omits ``worker_key``
          clears it. Confirm that callers intend this.

          ``started_time`` is set the first time the run enters ``"running"``.
          ``finished_time`` is set, or re-set, on every update to a terminal
          status. Timestamps are naive UTC datetimes.

          Returns the refreshed run, or ``None`` if no run has ``agent_run_id``.
          """
          entity = self.db.get(AgentRun, agent_run_id)
          if entity is None:
              return None
          # Naive UTC timestamp, equivalent to the deprecated ``datetime.utcnow()``
          # (deprecated since Python 3.12) without its DeprecationWarning.
          now = datetime.now(timezone.utc).replace(tzinfo=None)
          entity.status = status
          entity.worker_key = worker_key
          entity.output_text = output_text
          entity.output_json = output_json
          entity.error_code = error_code
          entity.error_message = error_message
          if status == "running" and entity.started_time is None:
              entity.started_time = now
          if status in self._TERMINAL_STATUSES:
              entity.finished_time = now
          self.db.commit()
          self.db.refresh(entity)
          return entity