repositories.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. from datetime import datetime
  2. from sqlalchemy import delete, select
  3. from sqlalchemy.orm import Session
  4. from core_domain import (
  5. AgentRunStatus,
  6. AgentStatus,
  7. AgentToolInvocationStatus)
  8. from core_shared import JSONValue
  9. from app.db.models import AgentDefinition, AgentRun, AgentToolInvocation, AgentConfig
  10. class AgentDefinitionRepository:
  11. def __init__(self, db: Session) -> None:
  12. self.db = db
  13. def create(
  14. self,
  15. *,
  16. code: str,
  17. name: str,
  18. description: str | None,
  19. agent_type: str,
  20. owner_user_id: str | None,
  21. metadata_json: dict[str, JSONValue] | None) -> AgentDefinition:
  22. entity = AgentDefinition(
  23. code=code,
  24. name=name,
  25. description=description,
  26. agent_type=agent_type,
  27. owner_user_id=owner_user_id,
  28. metadata_json=metadata_json)
  29. self.db.add(entity)
  30. self.db.commit()
  31. self.db.refresh(entity)
  32. return entity
  33. def list_all(self) -> list[AgentDefinition]:
  34. stmt = (
  35. select(AgentDefinition)
  36. .order_by(AgentDefinition.created_time.desc())
  37. )
  38. return list(self.db.scalars(stmt))
  39. def get_by_id(self, *, agent_id: str) -> AgentDefinition | None:
  40. stmt = (
  41. select(AgentDefinition)
  42. .where(AgentDefinition.id == agent_id)
  43. )
  44. return self.db.scalar(stmt)
  45. def delete(self, *, agent_id: str) -> AgentDefinition | None:
  46. entity = self.get_by_id(agent_id=agent_id)
  47. if entity is None:
  48. return None
  49. self.db.delete(entity)
  50. self.db.commit()
  51. return entity
  52. def update(
  53. self,
  54. *,
  55. agent_id: str,
  56. name: str | None,
  57. description: str | None,
  58. metadata_json: dict[str, JSONValue] | None) -> AgentDefinition | None:
  59. entity = self.get_by_id(agent_id=agent_id)
  60. if entity is None:
  61. return None
  62. if name is not None:
  63. entity.name = name
  64. if description is not None:
  65. entity.description = description
  66. if metadata_json is not None:
  67. entity.metadata_json = metadata_json
  68. self.db.commit()
  69. self.db.refresh(entity)
  70. return entity
  71. def update_status(
  72. self,
  73. *,
  74. agent_id: str,
  75. status: AgentStatus) -> AgentDefinition | None:
  76. entity = self.get_by_id(agent_id=agent_id)
  77. if entity is None:
  78. return None
  79. entity.status = status
  80. self.db.commit()
  81. self.db.refresh(entity)
  82. return entity
  83. class AgentConfigRepository:
  84. def __init__(self, db: Session) -> None:
  85. self.db = db
  86. def create(
  87. self,
  88. *,
  89. agent_id: str,
  90. role: str,
  91. goal: str | None,
  92. system_prompt: str,
  93. model_config_json: dict[str, JSONValue],
  94. memory_policy_json: dict[str, JSONValue],
  95. tool_refs_json: list[dict[str, JSONValue]],
  96. skill_refs_json: list[dict[str, JSONValue]]) -> AgentConfig:
  97. entity = AgentConfig(
  98. agent_id=agent_id,
  99. role=role,
  100. goal=goal,
  101. system_prompt=system_prompt,
  102. model_config_json=model_config_json,
  103. memory_policy_json=memory_policy_json,
  104. tool_refs_json=tool_refs_json,
  105. skill_refs_json=skill_refs_json)
  106. self.db.add(entity)
  107. self.db.commit()
  108. self.db.refresh(entity)
  109. return entity
  110. def list_by_agent(self, *, agent_id: str) -> list[AgentConfig]:
  111. stmt = (
  112. select(AgentConfig)
  113. .where(AgentConfig.agent_id == agent_id)
  114. .order_by(AgentConfig.created_time.desc())
  115. )
  116. return list(self.db.scalars(stmt))
  117. def get_by_id(self, *, agent_config_id: str) -> AgentConfig | None:
  118. stmt = (
  119. select(AgentConfig)
  120. .where(AgentConfig.id == agent_config_id)
  121. )
  122. return self.db.scalar(stmt)
  123. def get_latest_by_agent(self, *, agent_id: str) -> AgentConfig | None:
  124. stmt = (
  125. select(AgentConfig)
  126. .where(AgentConfig.agent_id == agent_id)
  127. .order_by(AgentConfig.created_time.desc())
  128. .limit(1)
  129. )
  130. return self.db.scalar(stmt)
  131. def delete_by_agent(self, *, agent_id: str) -> int:
  132. result = self.db.execute(
  133. delete(AgentConfig).where(AgentConfig.agent_id == agent_id))
  134. self.db.commit()
  135. return int(result.rowcount or 0)
  136. class AgentRunRepository:
  137. def __init__(self, db: Session) -> None:
  138. self.db = db
  139. def create(
  140. self,
  141. *,
  142. agent_id: str,
  143. agent_config_id: str,
  144. session_id: str | None,
  145. input_text: str | None,
  146. input_json: dict[str, JSONValue] | None) -> AgentRun:
  147. now = datetime.utcnow()
  148. entity = AgentRun(
  149. agent_id=agent_id,
  150. agent_config_id=agent_config_id,
  151. session_id=session_id,
  152. input_text=input_text,
  153. input_json=input_json,
  154. status="queued",
  155. queued_time=now)
  156. self.db.add(entity)
  157. self.db.commit()
  158. self.db.refresh(entity)
  159. return entity
  160. def list_by_scope(
  161. self,
  162. *,
  163. agent_id: str | None = None,
  164. session_id: str | None = None) -> list[AgentRun]:
  165. stmt = select(AgentRun)
  166. if agent_id is not None:
  167. stmt = stmt.where(AgentRun.agent_id == agent_id)
  168. if session_id is not None:
  169. stmt = stmt.where(AgentRun.session_id == session_id)
  170. stmt = stmt.order_by(AgentRun.created_time.desc())
  171. return list(self.db.scalars(stmt))
  172. def get_by_id(self, *, agent_run_id: str) -> AgentRun | None:
  173. stmt = (
  174. select(AgentRun)
  175. .where(AgentRun.id == agent_run_id)
  176. )
  177. return self.db.scalar(stmt)
  178. def delete_by_agent(self, *, agent_id: str) -> int:
  179. result = self.db.execute(
  180. delete(AgentRun).where(AgentRun.agent_id == agent_id))
  181. self.db.commit()
  182. return int(result.rowcount or 0)
  183. def claim_next_queued(
  184. self,
  185. *,
  186. worker_key: str,
  187. lease_expire_time: datetime) -> AgentRun | None:
  188. stmt = (
  189. select(AgentRun)
  190. .where(AgentRun.status == "queued")
  191. .order_by(AgentRun.created_time.asc())
  192. .with_for_update(skip_locked=True)
  193. .limit(1)
  194. )
  195. entity = self.db.scalar(stmt)
  196. if entity is None:
  197. return None
  198. now = datetime.utcnow()
  199. entity.status = "running"
  200. entity.worker_key = worker_key
  201. entity.started_time = entity.started_time or now
  202. entity.lease_expire_time = lease_expire_time
  203. self.db.commit()
  204. self.db.refresh(entity)
  205. return entity
  206. def release_expired_leases(self, *, now_time: datetime, max_items: int = 100) -> int:
  207. stmt = (
  208. select(AgentRun)
  209. .where(AgentRun.status == "running")
  210. .where(AgentRun.lease_expire_time.is_not(None))
  211. .where(AgentRun.lease_expire_time <= now_time)
  212. .order_by(AgentRun.lease_expire_time.asc())
  213. .limit(max_items)
  214. )
  215. entities = list(self.db.scalars(stmt))
  216. for entity in entities:
  217. entity.status = "queued"
  218. entity.worker_key = None
  219. entity.lease_expire_time = None
  220. entity.queued_time = now_time
  221. entity.started_time = None
  222. entity.finished_time = None
  223. if entities:
  224. self.db.commit()
  225. return len(entities)
  226. def update_status(
  227. self,
  228. *,
  229. agent_run_id: str,
  230. status: AgentRunStatus,
  231. worker_key: str | None = None,
  232. output_text: str | None = None,
  233. output_json: dict[str, JSONValue] | None = None,
  234. error_code: str | None = None,
  235. error_message: str | None = None) -> AgentRun | None:
  236. entity = self.db.get(AgentRun, agent_run_id)
  237. if entity is None:
  238. return None
  239. now = datetime.utcnow()
  240. entity.status = status
  241. entity.worker_key = worker_key
  242. entity.output_text = output_text
  243. entity.output_json = output_json
  244. entity.error_code = error_code
  245. entity.error_message = error_message
  246. if status == "running" and entity.started_time is None:
  247. entity.started_time = now
  248. if status in {"completed", "failed", "cancelled"}:
  249. entity.finished_time = now
  250. entity.lease_expire_time = None
  251. self.db.commit()
  252. self.db.refresh(entity)
  253. return entity
  254. class AgentToolInvocationRepository:
  255. def __init__(self, db: Session) -> None:
  256. self.db = db
  257. def create(
  258. self,
  259. *,
  260. agent_run_id: str,
  261. agent_id: str,
  262. agent_config_id: str,
  263. tool_code: str | None,
  264. tool_binding_id: str | None,
  265. status: AgentToolInvocationStatus,
  266. reason: str | None = None,
  267. input_json: dict[str, JSONValue] | None = None) -> AgentToolInvocation:
  268. entity = AgentToolInvocation(
  269. agent_run_id=agent_run_id,
  270. agent_id=agent_id,
  271. agent_config_id=agent_config_id,
  272. tool_code=tool_code,
  273. tool_binding_id=tool_binding_id,
  274. status=status,
  275. reason=reason,
  276. input_json=input_json or {})
  277. self.db.add(entity)
  278. self.db.commit()
  279. self.db.refresh(entity)
  280. return entity
  281. def list_by_run(
  282. self,
  283. *,
  284. agent_run_id: str) -> list[AgentToolInvocation]:
  285. stmt = (
  286. select(AgentToolInvocation)
  287. .where(AgentToolInvocation.agent_run_id == agent_run_id)
  288. .order_by(AgentToolInvocation.created_time.asc())
  289. )
  290. return list(self.db.scalars(stmt))
  291. def delete_by_run(self, *, agent_run_id: str) -> int:
  292. result = self.db.execute(
  293. delete(AgentToolInvocation).where(
  294. AgentToolInvocation.agent_run_id == agent_run_id))
  295. self.db.commit()
  296. return int(result.rowcount or 0)
  297. def delete_by_agent(self, *, agent_id: str) -> int:
  298. result = self.db.execute(
  299. delete(AgentToolInvocation).where(
  300. AgentToolInvocation.agent_id == agent_id))
  301. self.db.commit()
  302. return int(result.rowcount or 0)
  303. def update_status(
  304. self,
  305. *,
  306. invocation_id: str,
  307. status: AgentToolInvocationStatus,
  308. reason: str | None = None,
  309. output_text: str | None = None,
  310. output_json: dict[str, JSONValue] | None = None,
  311. error_message: str | None = None) -> AgentToolInvocation | None:
  312. entity = self.db.get(AgentToolInvocation, invocation_id)
  313. if entity is None:
  314. return None
  315. now = datetime.utcnow()
  316. entity.status = status
  317. entity.reason = reason
  318. entity.output_text = output_text
  319. entity.output_json = output_json
  320. entity.error_message = error_message
  321. if status == "running" and entity.started_time is None:
  322. entity.started_time = now
  323. if status in {"completed", "failed", "skipped"}:
  324. if entity.started_time is None:
  325. entity.started_time = now
  326. entity.finished_time = now
  327. self.db.commit()
  328. self.db.refresh(entity)
  329. return entity