repositories.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import (
    AgentRunStatus,
    AgentStatus,
    AgentToolInvocationStatus,
    AgentVersionStatus,
)
from core_shared import JSONValue

from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentVersion
  11. class AgentDefinitionRepository:
  12. def __init__(self, db: Session) -> None:
  13. self.db = db
  14. def create(
  15. self,
  16. *,
  17. code: str,
  18. name: str,
  19. description: str | None,
  20. agent_type: str,
  21. owner_user_id: str | None,
  22. metadata_json: dict[str, JSONValue] | None) -> AgentDefinition:
  23. entity = AgentDefinition(
  24. code=code,
  25. name=name,
  26. description=description,
  27. agent_type=agent_type,
  28. owner_user_id=owner_user_id,
  29. metadata_json=metadata_json)
  30. self.db.add(entity)
  31. self.db.commit()
  32. self.db.refresh(entity)
  33. return entity
  34. def list_all(self) -> list[AgentDefinition]:
  35. stmt = (
  36. select(AgentDefinition)
  37. .order_by(AgentDefinition.created_time.desc())
  38. )
  39. return list(self.db.scalars(stmt))
  40. def get_by_id(self, *, agent_id: str) -> AgentDefinition | None:
  41. stmt = (
  42. select(AgentDefinition)
  43. .where(AgentDefinition.id == agent_id)
  44. )
  45. return self.db.scalar(stmt)
  46. def update_status(
  47. self,
  48. *,
  49. agent_id: str,
  50. status: AgentStatus) -> AgentDefinition | None:
  51. entity = self.get_by_id(agent_id=agent_id)
  52. if entity is None:
  53. return None
  54. entity.status = status
  55. self.db.commit()
  56. self.db.refresh(entity)
  57. return entity
  58. class AgentVersionRepository:
  59. def __init__(self, db: Session) -> None:
  60. self.db = db
  61. def create(
  62. self,
  63. *,
  64. agent_id: str,
  65. status: AgentVersionStatus,
  66. role: str,
  67. goal: str | None,
  68. system_prompt: str,
  69. model_config_json: dict[str, JSONValue],
  70. memory_policy_json: dict[str, JSONValue],
  71. tool_refs_json: list[dict[str, JSONValue]],
  72. skill_refs_json: list[dict[str, JSONValue]]) -> AgentVersion:
  73. version_no = self._next_version_no(agent_id)
  74. entity = AgentVersion(
  75. agent_id=agent_id,
  76. version_no=version_no,
  77. status=status,
  78. role=role,
  79. goal=goal,
  80. system_prompt=system_prompt,
  81. model_config_json=model_config_json,
  82. memory_policy_json=memory_policy_json,
  83. tool_refs_json=tool_refs_json,
  84. skill_refs_json=skill_refs_json,
  85. published_time=datetime.utcnow() if status == "published" else None)
  86. self.db.add(entity)
  87. self.db.commit()
  88. self.db.refresh(entity)
  89. return entity
  90. def list_by_agent(self, *, agent_id: str) -> list[AgentVersion]:
  91. stmt = (
  92. select(AgentVersion)
  93. .where(AgentVersion.agent_id == agent_id)
  94. .order_by(AgentVersion.version_no.desc())
  95. )
  96. return list(self.db.scalars(stmt))
  97. def get_by_id(self, *, agent_version_id: str) -> AgentVersion | None:
  98. stmt = (
  99. select(AgentVersion)
  100. .where(AgentVersion.id == agent_version_id)
  101. )
  102. return self.db.scalar(stmt)
  103. def get_latest_published(self, *, agent_id: str) -> AgentVersion | None:
  104. stmt = (
  105. select(AgentVersion)
  106. .where(AgentVersion.agent_id == agent_id)
  107. .where(AgentVersion.status == "published")
  108. .order_by(AgentVersion.version_no.desc())
  109. .limit(1)
  110. )
  111. return self.db.scalar(stmt)
  112. def _next_version_no(self, agent_id: str) -> int:
  113. stmt = select(func.max(AgentVersion.version_no)).where(AgentVersion.agent_id == agent_id)
  114. current_max = self.db.scalar(stmt)
  115. return (current_max or 0) + 1
  116. class AgentRunRepository:
  117. def __init__(self, db: Session) -> None:
  118. self.db = db
  119. def create(
  120. self,
  121. *,
  122. agent_id: str,
  123. agent_version_id: str,
  124. session_id: str | None,
  125. input_text: str | None,
  126. input_json: dict[str, JSONValue] | None) -> AgentRun:
  127. now = datetime.utcnow()
  128. entity = AgentRun(
  129. agent_id=agent_id,
  130. agent_version_id=agent_version_id,
  131. session_id=session_id,
  132. input_text=input_text,
  133. input_json=input_json,
  134. status="queued",
  135. queued_time=now)
  136. self.db.add(entity)
  137. self.db.commit()
  138. self.db.refresh(entity)
  139. return entity
  140. def list_by_scope(
  141. self,
  142. *,
  143. agent_id: str | None = None,
  144. session_id: str | None = None) -> list[AgentRun]:
  145. stmt = select(AgentRun)
  146. if agent_id is not None:
  147. stmt = stmt.where(AgentRun.agent_id == agent_id)
  148. if session_id is not None:
  149. stmt = stmt.where(AgentRun.session_id == session_id)
  150. stmt = stmt.order_by(AgentRun.created_time.desc())
  151. return list(self.db.scalars(stmt))
  152. def get_by_id(self, *, agent_run_id: str) -> AgentRun | None:
  153. stmt = (
  154. select(AgentRun)
  155. .where(AgentRun.id == agent_run_id)
  156. )
  157. return self.db.scalar(stmt)
  158. def claim_next_queued(
  159. self,
  160. *,
  161. worker_key: str,
  162. lease_expire_time: datetime) -> AgentRun | None:
  163. stmt = (
  164. select(AgentRun)
  165. .where(AgentRun.status == "queued")
  166. .order_by(AgentRun.created_time.asc())
  167. .with_for_update(skip_locked=True)
  168. .limit(1)
  169. )
  170. entity = self.db.scalar(stmt)
  171. if entity is None:
  172. return None
  173. now = datetime.utcnow()
  174. entity.status = "running"
  175. entity.worker_key = worker_key
  176. entity.started_time = entity.started_time or now
  177. entity.lease_expire_time = lease_expire_time
  178. self.db.commit()
  179. self.db.refresh(entity)
  180. return entity
  181. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  182. stmt = (
  183. select(AgentRun)
  184. .where(AgentRun.status == "running")
  185. .where(AgentRun.lease_expire_time.is_not(None))
  186. .where(AgentRun.lease_expire_time <= now_time)
  187. .order_by(AgentRun.lease_expire_time.asc())
  188. .limit(max_items)
  189. )
  190. entities = list(self.db.scalars(stmt))
  191. for entity in entities:
  192. entity.status = "queued"
  193. entity.worker_key = None
  194. entity.lease_expire_time = None
  195. entity.queued_time = now_time
  196. entity.started_time = None
  197. entity.finished_time = None
  198. if entities:
  199. self.db.commit()
  200. return len(entities)
  201. def update_status(
  202. self,
  203. *,
  204. agent_run_id: str,
  205. status: AgentRunStatus,
  206. worker_key: str | None = None,
  207. output_text: str | None = None,
  208. output_json: dict[str, JSONValue] | None = None,
  209. error_code: str | None = None,
  210. error_message: str | None = None) -> AgentRun | None:
  211. entity = self.db.get(AgentRun, agent_run_id)
  212. if entity is None:
  213. return None
  214. now = datetime.utcnow()
  215. entity.status = status
  216. entity.worker_key = worker_key
  217. entity.output_text = output_text
  218. entity.output_json = output_json
  219. entity.error_code = error_code
  220. entity.error_message = error_message
  221. if status == "running" and entity.started_time is None:
  222. entity.started_time = now
  223. if status in {"completed", "failed", "cancelled"}:
  224. entity.finished_time = now
  225. entity.lease_expire_time = None
  226. self.db.commit()
  227. self.db.refresh(entity)
  228. return entity
  229. class AgentToolInvocationRepository:
  230. def __init__(self, db: Session) -> None:
  231. self.db = db
  232. def create(
  233. self,
  234. *,
  235. agent_run_id: str,
  236. agent_id: str,
  237. agent_version_id: str,
  238. tool_code: str | None,
  239. tool_binding_id: str | None,
  240. status: AgentToolInvocationStatus,
  241. reason: str | None = None,
  242. input_json: dict[str, JSONValue] | None = None) -> AgentToolInvocation:
  243. entity = AgentToolInvocation(
  244. agent_run_id=agent_run_id,
  245. agent_id=agent_id,
  246. agent_version_id=agent_version_id,
  247. tool_code=tool_code,
  248. tool_binding_id=tool_binding_id,
  249. status=status,
  250. reason=reason,
  251. input_json=input_json or {})
  252. self.db.add(entity)
  253. self.db.commit()
  254. self.db.refresh(entity)
  255. return entity
  256. def list_by_run(
  257. self,
  258. *,
  259. agent_run_id: str) -> list[AgentToolInvocation]:
  260. stmt = (
  261. select(AgentToolInvocation)
  262. .where(AgentToolInvocation.agent_run_id == agent_run_id)
  263. .order_by(AgentToolInvocation.created_time.asc())
  264. )
  265. return list(self.db.scalars(stmt))
  266. def update_status(
  267. self,
  268. *,
  269. invocation_id: str,
  270. status: AgentToolInvocationStatus,
  271. reason: str | None = None,
  272. output_text: str | None = None,
  273. output_json: dict[str, JSONValue] | None = None,
  274. error_message: str | None = None) -> AgentToolInvocation | None:
  275. entity = self.db.get(AgentToolInvocation, invocation_id)
  276. if entity is None:
  277. return None
  278. now = datetime.utcnow()
  279. entity.status = status
  280. entity.reason = reason
  281. entity.output_text = output_text
  282. entity.output_json = output_json
  283. entity.error_message = error_message
  284. if status == "running" and entity.started_time is None:
  285. entity.started_time = now
  286. if status in {"completed", "failed", "skipped"}:
  287. if entity.started_time is None:
  288. entity.started_time = now
  289. entity.finished_time = now
  290. self.db.commit()
  291. self.db.refresh(entity)
  292. return entity