services.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. from datetime import datetime, timedelta
  2. from typing import cast
  3. from sqlalchemy.orm import Session
  4. from core_domain import (
  5. ChatCompletionRequestContract,
  6. ChatMessageContract,
  7. MemoryCreateContract,
  8. MemoryScopeType,
  9. MemorySearchRequestContract,
  10. MemorySearchResultContract,
  11. )
  12. from core_shared import JSONValue
  13. from app.bootstrap.settings import AgentServiceSettings
  14. from app.db.models import AgentDefinition, AgentRun, AgentVersion
  15. from app.domain.repositories import (
  16. AgentDefinitionRepository,
  17. AgentRunRepository,
  18. AgentVersionRepository,
  19. )
  20. from app.infrastructure.model_gateway_client import ModelGatewayClient, ModelGatewayClientError
  21. from app.infrastructure.memory_client import MemoryClient, MemoryClientError
  22. from app.schemas.agent import (
  23. AgentCreateRequest,
  24. AgentRunCreateRequest,
  25. AgentRunExecuteRequest,
  26. AgentRunStatusUpdateRequest,
  27. AgentStatusUpdateRequest,
  28. AgentVersionCreateRequest,
  29. )
  30. class AgentApplicationService:
  31. def __init__(
  32. self,
  33. *,
  34. agent_repository: AgentDefinitionRepository,
  35. agent_version_repository: AgentVersionRepository,
  36. agent_run_repository: AgentRunRepository,
  37. model_gateway_client: ModelGatewayClient | None = None,
  38. memory_client: MemoryClient | None = None,
  39. ) -> None:
  40. self.agent_repository = agent_repository
  41. self.agent_version_repository = agent_version_repository
  42. self.agent_run_repository = agent_run_repository
  43. self.model_gateway_client = model_gateway_client
  44. self.memory_client = memory_client
  45. def create_agent(self, payload: AgentCreateRequest) -> AgentDefinition:
  46. return self.agent_repository.create(
  47. tenant_id=payload.tenant_id,
  48. code=payload.code,
  49. name=payload.name,
  50. description=payload.description,
  51. agent_type=payload.agent_type,
  52. owner_user_id=payload.owner_user_id,
  53. metadata_json=payload.metadata_json,
  54. )
  55. def list_agents(self, *, tenant_id: str) -> list[AgentDefinition]:
  56. return self.agent_repository.list_by_tenant(tenant_id=tenant_id)
  57. def update_agent_status(
  58. self,
  59. *,
  60. agent_id: str,
  61. payload: AgentStatusUpdateRequest,
  62. ) -> AgentDefinition | None:
  63. return self.agent_repository.update_status(
  64. tenant_id=payload.tenant_id,
  65. agent_id=agent_id,
  66. status=payload.status,
  67. )
  68. def create_agent_version(self, payload: AgentVersionCreateRequest) -> AgentVersion:
  69. agent = self.agent_repository.get_by_id(
  70. tenant_id=payload.tenant_id,
  71. agent_id=payload.agent_id,
  72. )
  73. if agent is None:
  74. raise ValueError(f"agent not found: {payload.agent_id}")
  75. return self.agent_version_repository.create(
  76. tenant_id=payload.tenant_id,
  77. agent_id=payload.agent_id,
  78. status=payload.status,
  79. role=payload.role,
  80. goal=payload.goal,
  81. system_prompt=payload.system_prompt,
  82. model_config_json=payload.model_config_data.model_dump(mode="json"),
  83. memory_policy_json=payload.memory_policy.model_dump(mode="json"),
  84. tool_refs_json=[item.model_dump(mode="json") for item in payload.tool_refs],
  85. skill_refs_json=[item.model_dump(mode="json") for item in payload.skill_refs],
  86. )
  87. def list_agent_versions(self, *, tenant_id: str, agent_id: str) -> list[AgentVersion]:
  88. return self.agent_version_repository.list_by_agent(tenant_id=tenant_id, agent_id=agent_id)
  89. def create_agent_run(self, payload: AgentRunCreateRequest) -> AgentRun:
  90. agent_version = self._resolve_agent_version(
  91. tenant_id=payload.tenant_id,
  92. agent_id=payload.agent_id,
  93. agent_version_id=payload.agent_version_id,
  94. )
  95. if agent_version is None:
  96. raise ValueError("published agent version not found")
  97. return self.agent_run_repository.create(
  98. tenant_id=payload.tenant_id,
  99. agent_id=payload.agent_id,
  100. agent_version_id=agent_version.id,
  101. session_id=payload.session_id,
  102. input_text=payload.input_text,
  103. input_json=payload.input_json,
  104. )
  105. def list_agent_runs(
  106. self,
  107. *,
  108. tenant_id: str,
  109. agent_id: str | None = None,
  110. session_id: str | None = None,
  111. ) -> list[AgentRun]:
  112. return self.agent_run_repository.list_by_scope(
  113. tenant_id=tenant_id,
  114. agent_id=agent_id,
  115. session_id=session_id,
  116. )
  117. def update_agent_run_status(
  118. self,
  119. *,
  120. agent_run_id: str,
  121. payload: AgentRunStatusUpdateRequest,
  122. ) -> AgentRun | None:
  123. entity = self.agent_run_repository.get_by_id(
  124. tenant_id=payload.tenant_id,
  125. agent_run_id=agent_run_id,
  126. )
  127. if entity is None:
  128. return None
  129. return self.agent_run_repository.update_status(
  130. agent_run_id=agent_run_id,
  131. status=payload.status,
  132. worker_key=payload.worker_key,
  133. output_text=payload.output_text,
  134. output_json=payload.output_json,
  135. error_code=payload.error_code,
  136. error_message=payload.error_message,
  137. )
  138. def execute_agent_run(
  139. self,
  140. *,
  141. agent_run_id: str,
  142. payload: AgentRunExecuteRequest,
  143. ) -> AgentRun | None:
  144. agent_run = self.agent_run_repository.get_by_id(
  145. tenant_id=payload.tenant_id,
  146. agent_run_id=agent_run_id,
  147. )
  148. if agent_run is None:
  149. return None
  150. agent_version = self.agent_version_repository.get_by_id(
  151. tenant_id=payload.tenant_id,
  152. agent_version_id=agent_run.agent_version_id,
  153. )
  154. if agent_version is None:
  155. return self.agent_run_repository.update_status(
  156. agent_run_id=agent_run.id,
  157. status="failed",
  158. worker_key=payload.worker_key,
  159. error_code="agent_version_missing",
  160. error_message=f"agent version not found: {agent_run.agent_version_id}",
  161. )
  162. self.agent_run_repository.update_status(
  163. agent_run_id=agent_run.id,
  164. status="running",
  165. worker_key=payload.worker_key,
  166. )
  167. memory_results, memory_metadata = self._read_relevant_memories(
  168. agent_run=agent_run,
  169. agent_version=agent_version,
  170. )
  171. messages = self._build_chat_messages(
  172. agent_run=agent_run,
  173. agent_version=agent_version,
  174. memory_results=memory_results,
  175. )
  176. if payload.dry_run:
  177. return self.agent_run_repository.update_status(
  178. agent_run_id=agent_run.id,
  179. status="completed",
  180. worker_key=payload.worker_key,
  181. output_text=self._build_dry_run_output(
  182. agent_run=agent_run,
  183. agent_version=agent_version,
  184. ),
  185. output_json={
  186. "dry_run": True,
  187. "agent_version_id": agent_version.id,
  188. "message_count": len(messages),
  189. "messages": [message.model_dump(mode="json") for message in messages],
  190. **memory_metadata,
  191. },
  192. )
  193. if self.model_gateway_client is None:
  194. return self.agent_run_repository.update_status(
  195. agent_run_id=agent_run.id,
  196. status="failed",
  197. worker_key=payload.worker_key,
  198. error_code="model_gateway_missing",
  199. error_message="model gateway client is not configured",
  200. )
  201. try:
  202. response = self.model_gateway_client.create_chat_completion(
  203. ChatCompletionRequestContract(
  204. model=self._read_optional_string(agent_version.model_config_json, "model"),
  205. temperature=self._read_optional_float(
  206. agent_version.model_config_json,
  207. "temperature",
  208. ),
  209. max_tokens=self._read_optional_int(agent_version.model_config_json, "max_tokens"),
  210. messages=messages,
  211. metadata_json={
  212. "tenant_id": agent_run.tenant_id,
  213. "agent_id": agent_run.agent_id,
  214. "agent_version_id": agent_version.id,
  215. "agent_run_id": agent_run.id,
  216. },
  217. )
  218. )
  219. except ModelGatewayClientError as exc:
  220. return self.agent_run_repository.update_status(
  221. agent_run_id=agent_run.id,
  222. status="failed",
  223. worker_key=payload.worker_key,
  224. error_code="model_gateway_error",
  225. error_message=str(exc),
  226. )
  227. memory_write_metadata = self._write_interaction_memory(
  228. agent_run=agent_run,
  229. agent_version=agent_version,
  230. output_text=response.content,
  231. )
  232. return self.agent_run_repository.update_status(
  233. agent_run_id=agent_run.id,
  234. status="completed",
  235. worker_key=payload.worker_key,
  236. output_text=response.content,
  237. output_json={
  238. "dry_run": False,
  239. "agent_version_id": agent_version.id,
  240. "model": response.model,
  241. "finish_reason": response.finish_reason,
  242. "usage_json": response.usage_json,
  243. "raw_response_json": response.raw_response_json,
  244. **memory_metadata,
  245. **memory_write_metadata,
  246. },
  247. )
  248. def execute_next_claimed_agent_run(
  249. self,
  250. *,
  251. worker_key: str,
  252. lease_seconds: int,
  253. dry_run: bool,
  254. ) -> tuple[AgentRun, int] | None:
  255. released_lease_count = self.agent_run_repository.release_expired_leases(
  256. now_time=datetime.utcnow(),
  257. )
  258. claimed_agent_run = self.agent_run_repository.claim_next_queued(
  259. worker_key=worker_key,
  260. lease_expire_time=datetime.utcnow() + timedelta(seconds=lease_seconds),
  261. )
  262. if claimed_agent_run is None:
  263. return None
  264. result = self.execute_agent_run(
  265. agent_run_id=claimed_agent_run.id,
  266. payload=AgentRunExecuteRequest(
  267. tenant_id=claimed_agent_run.tenant_id,
  268. worker_key=worker_key,
  269. dry_run=dry_run,
  270. ),
  271. )
  272. if result is None:
  273. return None
  274. return result, released_lease_count
  275. def _resolve_agent_version(
  276. self,
  277. *,
  278. tenant_id: str,
  279. agent_id: str,
  280. agent_version_id: str | None,
  281. ) -> AgentVersion | None:
  282. if agent_version_id is not None:
  283. return self.agent_version_repository.get_by_id(
  284. tenant_id=tenant_id,
  285. agent_version_id=agent_version_id,
  286. )
  287. return self.agent_version_repository.get_latest_published(
  288. tenant_id=tenant_id,
  289. agent_id=agent_id,
  290. )
  291. def _build_chat_messages(
  292. self,
  293. *,
  294. agent_run: AgentRun,
  295. agent_version: AgentVersion,
  296. memory_results: list[MemorySearchResultContract] | None = None,
  297. ) -> list[ChatMessageContract]:
  298. messages = [
  299. ChatMessageContract(role="system", content=agent_version.system_prompt),
  300. ]
  301. if agent_version.goal:
  302. messages.append(ChatMessageContract(role="system", content=f"Goal: {agent_version.goal}"))
  303. if memory_results:
  304. messages.append(
  305. ChatMessageContract(
  306. role="system",
  307. content=self._format_memory_context(memory_results),
  308. )
  309. )
  310. if agent_run.input_text:
  311. messages.append(ChatMessageContract(role="user", content=agent_run.input_text))
  312. if agent_run.input_json:
  313. messages.append(
  314. ChatMessageContract(
  315. role="user",
  316. content=f"Structured input: {agent_run.input_json}",
  317. )
  318. )
  319. return messages
  320. def _build_dry_run_output(self, *, agent_run: AgentRun, agent_version: AgentVersion) -> str:
  321. input_preview = agent_run.input_text or str(agent_run.input_json or {})
  322. return (
  323. f"[dry-run] Agent role={agent_version.role} "
  324. f"version={agent_version.version_no} received: {input_preview}"
  325. )
  326. def _read_optional_string(self, payload: dict[str, JSONValue], key: str) -> str | None:
  327. value = payload.get(key)
  328. if isinstance(value, str) and value:
  329. return value
  330. return None
  331. def _read_optional_float(self, payload: dict[str, JSONValue], key: str) -> float | None:
  332. value = payload.get(key)
  333. if isinstance(value, (int, float)) and not isinstance(value, bool):
  334. return float(value)
  335. return None
  336. def _read_optional_int(self, payload: dict[str, JSONValue], key: str) -> int | None:
  337. value = payload.get(key)
  338. if isinstance(value, int) and not isinstance(value, bool):
  339. return value
  340. return None
  341. def _read_relevant_memories(
  342. self,
  343. *,
  344. agent_run: AgentRun,
  345. agent_version: AgentVersion,
  346. ) -> tuple[list[MemorySearchResultContract], dict[str, JSONValue]]:
  347. if self.memory_client is None:
  348. return [], {"memory_read_enabled": False, "memory_read_reason": "client_missing"}
  349. if not self._read_bool(agent_version.memory_policy_json, "enabled", default=True):
  350. return [], {"memory_read_enabled": False, "memory_read_reason": "policy_disabled"}
  351. query = agent_run.input_text or str(agent_run.input_json or "")
  352. if not query:
  353. return [], {"memory_read_enabled": True, "memory_read_count": 0}
  354. scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
  355. if scope is None:
  356. return [], {
  357. "memory_read_enabled": True,
  358. "memory_read_count": 0,
  359. "memory_read_reason": "scope_unavailable",
  360. }
  361. scope_type, scope_id = scope
  362. try:
  363. results = self.memory_client.search_memories(
  364. MemorySearchRequestContract(
  365. tenant_id=agent_run.tenant_id,
  366. query=query,
  367. scope_type=scope_type,
  368. scope_id=scope_id,
  369. owner_agent_id=agent_run.agent_id,
  370. session_id=agent_run.session_id,
  371. limit=self._read_int(
  372. agent_version.memory_policy_json,
  373. "read_top_k",
  374. default=8,
  375. ),
  376. )
  377. )
  378. except MemoryClientError as exc:
  379. return [], {
  380. "memory_read_enabled": True,
  381. "memory_read_count": 0,
  382. "memory_read_error": str(exc),
  383. }
  384. return results, {
  385. "memory_read_enabled": True,
  386. "memory_read_count": len(results),
  387. "memory_scope_type": scope_type,
  388. "memory_scope_id": scope_id,
  389. }
  390. def _write_interaction_memory(
  391. self,
  392. *,
  393. agent_run: AgentRun,
  394. agent_version: AgentVersion,
  395. output_text: str,
  396. ) -> dict[str, JSONValue]:
  397. if self.memory_client is None:
  398. return {"memory_write_enabled": False, "memory_write_reason": "client_missing"}
  399. if not self._read_bool(agent_version.memory_policy_json, "write_enabled", default=True):
  400. return {"memory_write_enabled": False, "memory_write_reason": "policy_disabled"}
  401. scope = self._resolve_memory_scope(agent_run=agent_run, agent_version=agent_version)
  402. if scope is None:
  403. return {"memory_write_enabled": True, "memory_write_reason": "scope_unavailable"}
  404. scope_type, scope_id = scope
  405. try:
  406. memory = self.memory_client.create_memory(
  407. MemoryCreateContract(
  408. tenant_id=agent_run.tenant_id,
  409. scope_type=scope_type,
  410. scope_id=scope_id,
  411. memory_type="conversation",
  412. content_text=self._format_interaction_memory(
  413. agent_run=agent_run,
  414. output_text=output_text,
  415. ),
  416. content_json={
  417. "agent_run_id": agent_run.id,
  418. "agent_version_id": agent_version.id,
  419. "input_text": agent_run.input_text,
  420. "output_text": output_text,
  421. },
  422. metadata_json={
  423. "source": "agent-service",
  424. "role": agent_version.role,
  425. "version_no": agent_version.version_no,
  426. },
  427. owner_agent_id=agent_run.agent_id,
  428. session_id=agent_run.session_id,
  429. source_ref=f"agent_run:{agent_run.id}",
  430. importance_score=self._read_nested_int(
  431. agent_version.memory_policy_json,
  432. "config_json",
  433. "write_importance_score",
  434. default=50,
  435. ),
  436. )
  437. )
  438. except MemoryClientError as exc:
  439. return {
  440. "memory_write_enabled": True,
  441. "memory_write_error": str(exc),
  442. }
  443. return {
  444. "memory_write_enabled": True,
  445. "memory_written_id": memory.id,
  446. "memory_scope_type": scope_type,
  447. "memory_scope_id": scope_id,
  448. }
  449. def _resolve_memory_scope(
  450. self,
  451. *,
  452. agent_run: AgentRun,
  453. agent_version: AgentVersion,
  454. ) -> tuple[MemoryScopeType, str] | None:
  455. scope_value = self._read_optional_string(
  456. agent_version.memory_policy_json,
  457. "memory_scope",
  458. ) or "session"
  459. if scope_value == "tenant":
  460. return "tenant", agent_run.tenant_id
  461. if scope_value == "agent":
  462. return "agent", agent_run.agent_id
  463. if scope_value == "session" and agent_run.session_id:
  464. return "session", agent_run.session_id
  465. if scope_value == "user":
  466. user_id = self._read_input_json_string(agent_run=agent_run, key="user_id")
  467. if user_id is not None:
  468. return "user", user_id
  469. if scope_value == "team":
  470. team_id = self._read_input_json_string(agent_run=agent_run, key="team_id")
  471. if team_id is not None:
  472. return "team", team_id
  473. return None
  474. def _format_memory_context(self, memory_results: list[MemorySearchResultContract]) -> str:
  475. lines = ["Relevant memories:"]
  476. for index, result in enumerate(memory_results, start=1):
  477. lines.append(f"{index}. {result.item.content_text}")
  478. return "\n".join(lines)
  479. def _format_interaction_memory(self, *, agent_run: AgentRun, output_text: str) -> str:
  480. input_text = agent_run.input_text or str(agent_run.input_json or {})
  481. return f"User input: {input_text}\nAgent output: {output_text}"
  482. def _read_bool(self, payload: dict[str, JSONValue], key: str, *, default: bool) -> bool:
  483. value = payload.get(key)
  484. if isinstance(value, bool):
  485. return value
  486. return default
  487. def _read_int(self, payload: dict[str, JSONValue], key: str, *, default: int) -> int:
  488. value = payload.get(key)
  489. if isinstance(value, int) and not isinstance(value, bool):
  490. return value
  491. return default
  492. def _read_nested_int(
  493. self,
  494. payload: dict[str, JSONValue],
  495. parent_key: str,
  496. child_key: str,
  497. *,
  498. default: int,
  499. ) -> int:
  500. parent_value = payload.get(parent_key)
  501. if not isinstance(parent_value, dict):
  502. return default
  503. return self._read_int(
  504. cast(dict[str, JSONValue], parent_value),
  505. child_key,
  506. default=default,
  507. )
  508. def _read_input_json_string(self, *, agent_run: AgentRun, key: str) -> str | None:
  509. if agent_run.input_json is None:
  510. return None
  511. value = agent_run.input_json.get(key)
  512. if isinstance(value, str) and value:
  513. return value
  514. return None
  515. def build_agent_application_service(
  516. *,
  517. db: Session,
  518. settings: AgentServiceSettings,
  519. ) -> AgentApplicationService:
  520. return AgentApplicationService(
  521. agent_repository=AgentDefinitionRepository(db),
  522. agent_version_repository=AgentVersionRepository(db),
  523. agent_run_repository=AgentRunRepository(db),
  524. model_gateway_client=ModelGatewayClient(
  525. base_url=settings.model_gateway_service_url,
  526. timeout_seconds=settings.model_gateway_timeout_seconds,
  527. ),
  528. memory_client=MemoryClient(
  529. base_url=settings.memory_service_url,
  530. timeout_seconds=settings.memory_service_timeout_seconds,
  531. ),
  532. )