repositories.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core_domain import (
    AgentRunStatus,
    AgentStatus,
    AgentToolInvocationStatus,
    AgentVersionStatus,
)
from core_shared import JSONValue

from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentVersion
  12. class AgentDefinitionRepository:
  13. def __init__(self, db: Session) -> None:
  14. self.db = db
  15. def create(
  16. self,
  17. *,
  18. tenant_id: str,
  19. code: str,
  20. name: str,
  21. description: str | None,
  22. agent_type: str,
  23. owner_user_id: str | None,
  24. metadata_json: dict[str, JSONValue] | None,
  25. ) -> AgentDefinition:
  26. entity = AgentDefinition(
  27. tenant_id=tenant_id,
  28. code=code,
  29. name=name,
  30. description=description,
  31. agent_type=agent_type,
  32. owner_user_id=owner_user_id,
  33. metadata_json=metadata_json,
  34. )
  35. self.db.add(entity)
  36. self.db.commit()
  37. self.db.refresh(entity)
  38. return entity
  39. def list_by_tenant(self, *, tenant_id: str) -> list[AgentDefinition]:
  40. stmt = (
  41. select(AgentDefinition)
  42. .where(AgentDefinition.tenant_id == tenant_id)
  43. .order_by(AgentDefinition.created_time.desc())
  44. )
  45. return list(self.db.scalars(stmt))
  46. def get_by_id(self, *, tenant_id: str, agent_id: str) -> AgentDefinition | None:
  47. stmt = (
  48. select(AgentDefinition)
  49. .where(AgentDefinition.tenant_id == tenant_id)
  50. .where(AgentDefinition.id == agent_id)
  51. )
  52. return self.db.scalar(stmt)
  53. def update_status(
  54. self,
  55. *,
  56. tenant_id: str,
  57. agent_id: str,
  58. status: AgentStatus,
  59. ) -> AgentDefinition | None:
  60. entity = self.get_by_id(tenant_id=tenant_id, agent_id=agent_id)
  61. if entity is None:
  62. return None
  63. entity.status = status
  64. self.db.commit()
  65. self.db.refresh(entity)
  66. return entity
  67. class AgentVersionRepository:
  68. def __init__(self, db: Session) -> None:
  69. self.db = db
  70. def create(
  71. self,
  72. *,
  73. tenant_id: str,
  74. agent_id: str,
  75. status: AgentVersionStatus,
  76. role: str,
  77. goal: str | None,
  78. system_prompt: str,
  79. model_config_json: dict[str, JSONValue],
  80. memory_policy_json: dict[str, JSONValue],
  81. tool_refs_json: list[dict[str, JSONValue]],
  82. skill_refs_json: list[dict[str, JSONValue]],
  83. ) -> AgentVersion:
  84. version_no = self._next_version_no(agent_id)
  85. entity = AgentVersion(
  86. tenant_id=tenant_id,
  87. agent_id=agent_id,
  88. version_no=version_no,
  89. status=status,
  90. role=role,
  91. goal=goal,
  92. system_prompt=system_prompt,
  93. model_config_json=model_config_json,
  94. memory_policy_json=memory_policy_json,
  95. tool_refs_json=tool_refs_json,
  96. skill_refs_json=skill_refs_json,
  97. published_time=datetime.utcnow() if status == "published" else None,
  98. )
  99. self.db.add(entity)
  100. self.db.commit()
  101. self.db.refresh(entity)
  102. return entity
  103. def list_by_agent(self, *, tenant_id: str, agent_id: str) -> list[AgentVersion]:
  104. stmt = (
  105. select(AgentVersion)
  106. .where(AgentVersion.tenant_id == tenant_id)
  107. .where(AgentVersion.agent_id == agent_id)
  108. .order_by(AgentVersion.version_no.desc())
  109. )
  110. return list(self.db.scalars(stmt))
  111. def get_by_id(self, *, tenant_id: str, agent_version_id: str) -> AgentVersion | None:
  112. stmt = (
  113. select(AgentVersion)
  114. .where(AgentVersion.tenant_id == tenant_id)
  115. .where(AgentVersion.id == agent_version_id)
  116. )
  117. return self.db.scalar(stmt)
  118. def get_latest_published(self, *, tenant_id: str, agent_id: str) -> AgentVersion | None:
  119. stmt = (
  120. select(AgentVersion)
  121. .where(AgentVersion.tenant_id == tenant_id)
  122. .where(AgentVersion.agent_id == agent_id)
  123. .where(AgentVersion.status == "published")
  124. .order_by(AgentVersion.version_no.desc())
  125. .limit(1)
  126. )
  127. return self.db.scalar(stmt)
  128. def _next_version_no(self, agent_id: str) -> int:
  129. stmt = select(func.max(AgentVersion.version_no)).where(AgentVersion.agent_id == agent_id)
  130. current_max = self.db.scalar(stmt)
  131. return (current_max or 0) + 1
  132. class AgentRunRepository:
  133. def __init__(self, db: Session) -> None:
  134. self.db = db
  135. def create(
  136. self,
  137. *,
  138. tenant_id: str,
  139. agent_id: str,
  140. agent_version_id: str,
  141. session_id: str | None,
  142. input_text: str | None,
  143. input_json: dict[str, JSONValue] | None,
  144. ) -> AgentRun:
  145. now = datetime.utcnow()
  146. entity = AgentRun(
  147. tenant_id=tenant_id,
  148. agent_id=agent_id,
  149. agent_version_id=agent_version_id,
  150. session_id=session_id,
  151. input_text=input_text,
  152. input_json=input_json,
  153. status="queued",
  154. queued_time=now,
  155. )
  156. self.db.add(entity)
  157. self.db.commit()
  158. self.db.refresh(entity)
  159. return entity
  160. def list_by_scope(
  161. self,
  162. *,
  163. tenant_id: str,
  164. agent_id: str | None = None,
  165. session_id: str | None = None,
  166. ) -> list[AgentRun]:
  167. stmt = select(AgentRun).where(AgentRun.tenant_id == tenant_id)
  168. if agent_id is not None:
  169. stmt = stmt.where(AgentRun.agent_id == agent_id)
  170. if session_id is not None:
  171. stmt = stmt.where(AgentRun.session_id == session_id)
  172. stmt = stmt.order_by(AgentRun.created_time.desc())
  173. return list(self.db.scalars(stmt))
  174. def get_by_id(self, *, tenant_id: str, agent_run_id: str) -> AgentRun | None:
  175. stmt = (
  176. select(AgentRun)
  177. .where(AgentRun.tenant_id == tenant_id)
  178. .where(AgentRun.id == agent_run_id)
  179. )
  180. return self.db.scalar(stmt)
  181. def claim_next_queued(
  182. self,
  183. *,
  184. worker_key: str,
  185. lease_expire_time: datetime,
  186. ) -> AgentRun | None:
  187. stmt = (
  188. select(AgentRun)
  189. .where(AgentRun.status == "queued")
  190. .order_by(AgentRun.created_time.asc())
  191. .with_for_update(skip_locked=True)
  192. .limit(1)
  193. )
  194. entity = self.db.scalar(stmt)
  195. if entity is None:
  196. return None
  197. now = datetime.utcnow()
  198. entity.status = "running"
  199. entity.worker_key = worker_key
  200. entity.started_time = entity.started_time or now
  201. entity.lease_expire_time = lease_expire_time
  202. self.db.commit()
  203. self.db.refresh(entity)
  204. return entity
  205. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  206. stmt = (
  207. select(AgentRun)
  208. .where(AgentRun.status == "running")
  209. .where(AgentRun.lease_expire_time.is_not(None))
  210. .where(AgentRun.lease_expire_time <= now_time)
  211. .order_by(AgentRun.lease_expire_time.asc())
  212. .limit(max_items)
  213. )
  214. entities = list(self.db.scalars(stmt))
  215. for entity in entities:
  216. entity.status = "queued"
  217. entity.worker_key = None
  218. entity.lease_expire_time = None
  219. entity.queued_time = now_time
  220. entity.started_time = None
  221. entity.finished_time = None
  222. if entities:
  223. self.db.commit()
  224. return len(entities)
  225. def update_status(
  226. self,
  227. *,
  228. agent_run_id: str,
  229. status: AgentRunStatus,
  230. worker_key: str | None = None,
  231. output_text: str | None = None,
  232. output_json: dict[str, JSONValue] | None = None,
  233. error_code: str | None = None,
  234. error_message: str | None = None,
  235. ) -> AgentRun | None:
  236. entity = self.db.get(AgentRun, agent_run_id)
  237. if entity is None:
  238. return None
  239. now = datetime.utcnow()
  240. entity.status = status
  241. entity.worker_key = worker_key
  242. entity.output_text = output_text
  243. entity.output_json = output_json
  244. entity.error_code = error_code
  245. entity.error_message = error_message
  246. if status == "running" and entity.started_time is None:
  247. entity.started_time = now
  248. if status in {"completed", "failed", "cancelled"}:
  249. entity.finished_time = now
  250. entity.lease_expire_time = None
  251. self.db.commit()
  252. self.db.refresh(entity)
  253. return entity
  254. class AgentToolInvocationRepository:
  255. def __init__(self, db: Session) -> None:
  256. self.db = db
  257. def create(
  258. self,
  259. *,
  260. tenant_id: str,
  261. agent_run_id: str,
  262. agent_id: str,
  263. agent_version_id: str,
  264. tool_code: str | None,
  265. tool_binding_id: str | None,
  266. status: AgentToolInvocationStatus,
  267. reason: str | None = None,
  268. input_json: dict[str, JSONValue] | None = None,
  269. ) -> AgentToolInvocation:
  270. entity = AgentToolInvocation(
  271. tenant_id=tenant_id,
  272. agent_run_id=agent_run_id,
  273. agent_id=agent_id,
  274. agent_version_id=agent_version_id,
  275. tool_code=tool_code,
  276. tool_binding_id=tool_binding_id,
  277. status=status,
  278. reason=reason,
  279. input_json=input_json or {},
  280. )
  281. self.db.add(entity)
  282. self.db.commit()
  283. self.db.refresh(entity)
  284. return entity
  285. def list_by_run(
  286. self,
  287. *,
  288. tenant_id: str,
  289. agent_run_id: str,
  290. ) -> list[AgentToolInvocation]:
  291. stmt = (
  292. select(AgentToolInvocation)
  293. .where(AgentToolInvocation.tenant_id == tenant_id)
  294. .where(AgentToolInvocation.agent_run_id == agent_run_id)
  295. .order_by(AgentToolInvocation.created_time.asc())
  296. )
  297. return list(self.db.scalars(stmt))
  298. def update_status(
  299. self,
  300. *,
  301. invocation_id: str,
  302. status: AgentToolInvocationStatus,
  303. reason: str | None = None,
  304. output_text: str | None = None,
  305. output_json: dict[str, JSONValue] | None = None,
  306. error_message: str | None = None,
  307. ) -> AgentToolInvocation | None:
  308. entity = self.db.get(AgentToolInvocation, invocation_id)
  309. if entity is None:
  310. return None
  311. now = datetime.utcnow()
  312. entity.status = status
  313. entity.reason = reason
  314. entity.output_text = output_text
  315. entity.output_json = output_json
  316. entity.error_message = error_message
  317. if status == "running" and entity.started_time is None:
  318. entity.started_time = now
  319. if status in {"completed", "failed", "skipped"}:
  320. if entity.started_time is None:
  321. entity.started_time = now
  322. entity.finished_time = now
  323. self.db.commit()
  324. self.db.refresh(entity)
  325. return entity