test_team_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. from pathlib import Path
  2. from datetime import datetime
  3. from unittest.mock import MagicMock
  4. from tests.conftest import (
  5. build_fastapi_test_client,
  6. build_postgres_database_url,
  7. build_postgres_engine,
  8. prepare_known_service_import,
  9. )
  10. def test_team_service_post_contract_supports_team_configs_and_runs(
  11. tmp_path: Path,
  12. monkeypatch,
  13. ) -> None:
  14. prepare_known_service_import("team-service")
  15. from app.bootstrap.app import create_app
  16. from app.db.models import Base
  17. from core_db import create_session_factory
  18. database_url = build_postgres_database_url(tmp_path, "teams-api")
  19. monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
  20. monkeypatch.setenv("AGENT_PLATFORM_REDIS_URL", "")
  21. monkeypatch.setenv("AGENT_PLATFORM_AUTO_WORKER_ENABLED", "false")
  22. engine = build_postgres_engine(database_url)
  23. Base.metadata.create_all(engine)
  24. app = create_app()
  25. app.state.session_factory = create_session_factory(engine)
  26. client = build_fastapi_test_client(app)
  27. team_response = client.post(
  28. "/teams/create",
  29. json={
  30. "name": "Support Team",
  31. "description": "Handles support escalations",
  32. "teamType": "collaborative",
  33. "ownerUserId": "demo-user",
  34. },
  35. )
  36. assert team_response.status_code == 200
  37. team_payload = team_response.json()["data"]
  38. assert team_payload["name"] == "Support Team"
  39. assert team_payload["teamType"] == "collaborative"
  40. assert "code" not in team_payload
  41. config_response = client.post(
  42. "/teams/configs/create",
  43. json={
  44. "teamId": team_payload["id"],
  45. "coordinationMode": "supervisor",
  46. "objective": "Resolve the customer request",
  47. "memberRefs": [
  48. {
  49. "role": "worker",
  50. "agentId": "agent_support",
  51. "responsibility": "Draft the answer",
  52. }
  53. ],
  54. "policy": {
  55. "max_rounds": 3,
  56. "handoff": "supervisor",
  57. },
  58. },
  59. )
  60. assert config_response.status_code == 200
  61. config_payload = config_response.json()["data"]
  62. assert config_payload["teamId"] == team_payload["id"]
  63. assert config_payload["memberRefs"][0]["role"] == "specialist"
  64. assert config_payload["memberRefs"][0]["member_key"] == "member_1"
  65. list_response = client.post(
  66. "/teams/list",
  67. json={"page": 1, "pageSize": 20, "keyword": "support"},
  68. )
  69. assert list_response.status_code == 200
  70. assert list_response.json()["data"]["total"] == 1
  71. configs_response = client.post(
  72. "/teams/configs/list",
  73. json={"page": 1, "pageSize": 20, "teamId": team_payload["id"]},
  74. )
  75. assert configs_response.status_code == 200
  76. assert configs_response.json()["data"]["total"] == 1
  77. run_response = client.post(
  78. "/teams/runs/create",
  79. json={
  80. "teamId": team_payload["id"],
  81. "teamConfigId": config_payload["id"],
  82. "inputText": "Help the customer reset MFA",
  83. },
  84. )
  85. assert run_response.status_code == 200
  86. run_payload = run_response.json()["data"]
  87. assert run_payload["teamId"] == team_payload["id"]
  88. assert run_payload["teamConfigId"] == config_payload["id"]
  89. assert run_payload["status"] == "queued"
  90. status_response = client.post(
  91. "/teams/runs/status",
  92. json={
  93. "teamRunId": run_payload["id"],
  94. "status": "completed",
  95. "workerKey": "test-worker",
  96. "outputText": "MFA reset steps prepared.",
  97. },
  98. )
  99. assert status_response.status_code == 200
  100. assert status_response.json()["data"]["status"] == "completed"
  101. runs_response = client.post(
  102. "/teams/runs/list",
  103. json={"page": 1, "pageSize": 20, "teamId": team_payload["id"]},
  104. )
  105. assert runs_response.status_code == 200
  106. assert runs_response.json()["data"]["total"] == 1
  107. update_response = client.post(
  108. "/teams/update",
  109. json={
  110. "teamId": team_payload["id"],
  111. "name": "Support Team Updated",
  112. "status": "active",
  113. },
  114. )
  115. assert update_response.status_code == 200
  116. assert update_response.json()["data"]["name"] == "Support Team Updated"
  117. assert update_response.json()["data"]["status"] == "active"
  118. delete_run_response = client.post(
  119. "/teams/runs/delete",
  120. json={"teamRunId": run_payload["id"]},
  121. )
  122. assert delete_run_response.status_code == 200
  123. assert delete_run_response.json()["data"]["deleted"] is True
  124. delete_config_response = client.post(
  125. "/teams/configs/delete",
  126. json={"configId": config_payload["id"]},
  127. )
  128. assert delete_config_response.status_code == 200
  129. assert delete_config_response.json()["data"]["deleted"] is True
  130. delete_team_response = client.post(
  131. "/teams/delete",
  132. json={"teamId": team_payload["id"]},
  133. )
  134. assert delete_team_response.status_code == 200
  135. assert delete_team_response.json()["data"]["deleted"] is True
  136. def test_team_service_compacts_member_context_between_agent_calls() -> None:
  137. prepare_known_service_import("team-service")
  138. from app.application.services import TeamApplicationService, TeamMemberRunResult
  139. from core_domain import AgentRunContract, TeamMemberContract
  140. service = TeamApplicationService(
  141. team_repository=None,
  142. team_config_repository=None,
  143. team_run_repository=None)
  144. result = TeamMemberRunResult(
  145. member=TeamMemberContract(
  146. member_key="member_1",
  147. agent_id="agent_1",
  148. role="supervisor",
  149. name="Planner"),
  150. run=AgentRunContract(
  151. id="run_1",
  152. agent_id="agent_1",
  153. agent_config_id="config_1",
  154. output_text="x" * 2000,
  155. output_json={
  156. "model": "demo-model",
  157. "finish_reason": "stop",
  158. "messages": [{"role": "user", "content": "large prompt"}],
  159. "raw_response_json": {"thinking": "hidden debug payload"},
  160. },
  161. status="completed",
  162. created_time=datetime.utcnow()))
  163. prior_output = service._compact_prior_output(result)
  164. member_json = service._member_result_to_json(result)
  165. assert prior_output["output_text"].endswith("[truncated]")
  166. assert prior_output["output_json"] == {
  167. "model": "demo-model",
  168. "finish_reason": "stop",
  169. "debug_payload_omitted": True,
  170. }
  171. assert "messages" not in member_json["output_json"]
  172. assert "raw_response_json" not in member_json["output_json"]
  173. def _build_service_with_mock_agent() -> tuple:
  174. prepare_known_service_import("team-service")
  175. from app.application.services import TeamApplicationService, TeamMemberRunResult
  176. from core_domain import AgentRunContract, TeamMemberContract
  177. call_log: list[str] = []
  178. def make_member_result(member: TeamMemberContract, text: str) -> TeamMemberRunResult:
  179. return TeamMemberRunResult(
  180. member=member,
  181. run=AgentRunContract(
  182. id=f"run_{member.member_key}",
  183. agent_id=member.agent_id,
  184. agent_config_id=member.agent_config_id or "default_config",
  185. output_text=text,
  186. output_json={},
  187. status="completed",
  188. created_time=datetime.utcnow()))
  189. mock_client = MagicMock()
  190. mock_client.create_agent_run = MagicMock(
  191. side_effect=lambda **kw: AgentRunContract(
  192. id="run_mock", agent_id=kw.get("agent_id", "a"),
  193. status="created", created_time=datetime.utcnow()))
  194. mock_client.execute_agent_run = MagicMock(
  195. side_effect=lambda **kw: AgentRunContract(
  196. id=kw.get("agent_run_id", "run_mock"),
  197. agent_id="a", output_text="ok",
  198. output_json={}, status="completed",
  199. created_time=datetime.utcnow()))
  200. def track_execute(team_run, team_config, member, member_input_json, worker_key, dry_run):
  201. prior = member_input_json.get("prior_member_outputs", [])
  202. call_log.append(f"{member.member_key}:{member.role}:prior={len(prior)}")
  203. return make_member_result(member, f"output_{member.member_key}")
  204. service = TeamApplicationService(
  205. team_repository=None,
  206. team_config_repository=None,
  207. team_run_repository=None,
  208. agent_client=mock_client)
  209. return service, call_log, track_execute, make_member_result
  210. def _mock_team_run(**overrides):
  211. run = MagicMock()
  212. run.id = "run_test"
  213. run.team_id = "team_test"
  214. run.team_config_id = "config_test"
  215. run.session_id = None
  216. run.input_text = "test input"
  217. run.input_json = None
  218. for k, v in overrides.items():
  219. setattr(run, k, v)
  220. return run
  221. def _mock_team_config(**overrides):
  222. config = MagicMock()
  223. config.id = "config_test"
  224. config.team_id = "team_test"
  225. config.coordination_mode = "supervisor"
  226. config.objective = "test objective"
  227. config.policy_json = {}
  228. for k, v in overrides.items():
  229. setattr(config, k, v)
  230. return config
  231. def test_supervisor_mode_executes_lead_first_then_others() -> None:
  232. service, call_log, track_execute, _ = _build_service_with_mock_agent()
  233. from unittest.mock import patch
  234. from core_domain import TeamMemberContract
  235. members = [
  236. TeamMemberContract(member_key="worker_1", agent_id="a1", role="executor"),
  237. TeamMemberContract(member_key="lead_1", agent_id="a2", role="supervisor"),
  238. TeamMemberContract(member_key="worker_2", agent_id="a3", role="reviewer"),
  239. ]
  240. team_config = _mock_team_config(
  241. coordination_mode="supervisor",
  242. policy_json={"supervisor_synthesis": True})
  243. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  244. results = service._execute_members(
  245. team_run=_mock_team_run(), team_config=team_config,
  246. members=members, worker_key=None, dry_run=False)
  247. # lead runs first, then workers, then synthesis = 4 executions
  248. assert len(results) == 4
  249. keys = [r.member.member_key for r in results]
  250. assert keys[0] == "lead_1" # supervisor first
  251. assert "worker_1" in keys[1:3]
  252. assert "worker_2" in keys[1:3]
  253. assert keys[3] == "lead_1" # synthesis pass
  254. def test_pipeline_mode_chains_single_prior_output() -> None:
  255. service, call_log, track_execute, _ = _build_service_with_mock_agent()
  256. from unittest.mock import patch
  257. from core_domain import TeamMemberContract
  258. members = [
  259. TeamMemberContract(member_key="m1", agent_id="a1", role="planner"),
  260. TeamMemberContract(member_key="m2", agent_id="a2", role="executor"),
  261. TeamMemberContract(member_key="m3", agent_id="a3", role="reviewer"),
  262. ]
  263. team_config = _mock_team_config(
  264. coordination_mode="pipeline",
  265. policy_json={})
  266. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  267. results = service._execute_members(
  268. team_run=_mock_team_run(), team_config=team_config,
  269. members=members, worker_key=None, dry_run=False)
  270. assert len(results) == 3
  271. # m1: no prior, m2: 1 prior, m3: 1 prior (only previous, not all)
  272. assert call_log[0] == "m1:planner:prior=0"
  273. assert call_log[1] == "m2:executor:prior=1"
  274. assert call_log[2] == "m3:reviewer:prior=1"
  275. def test_debate_mode_executes_multiple_rounds() -> None:
  276. service, call_log, track_execute, _ = _build_service_with_mock_agent()
  277. from unittest.mock import patch
  278. from core_domain import TeamMemberContract
  279. members = [
  280. TeamMemberContract(member_key="m1", agent_id="a1", role="executor"),
  281. TeamMemberContract(member_key="m2", agent_id="a2", role="reviewer"),
  282. ]
  283. team_config = _mock_team_config(
  284. coordination_mode="debate",
  285. policy_json={"max_rounds": 3})
  286. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  287. results = service._execute_members(
  288. team_run=_mock_team_run(), team_config=team_config,
  289. members=members, worker_key=None, dry_run=False)
  290. # 2 members x 3 rounds = 6 executions, final_results = last round
  291. assert len(results) == 2
  292. assert len(call_log) == 6
  293. # Round 1: prior=0 for first, prior=1 for second
  294. assert call_log[0] == "m1:executor:prior=0"
  295. assert call_log[1] == "m2:reviewer:prior=1"
  296. # Round 2: prior=2 (history from round 1)
  297. assert call_log[2] == "m1:executor:prior=2"
  298. assert call_log[3] == "m2:reviewer:prior=3"
  299. # Round 3: prior=4
  300. assert call_log[4] == "m1:executor:prior=4"
  301. def test_failure_mode_continue_allows_partial_failure() -> None:
  302. service, call_log, _, make_member_result = _build_service_with_mock_agent()
  303. from unittest.mock import patch
  304. from core_domain import TeamMemberContract, AgentRunContract
  305. members = [
  306. TeamMemberContract(member_key="m1", agent_id="a1", role="executor"),
  307. TeamMemberContract(member_key="m2", agent_id="a2", role="executor"),
  308. ]
  309. team_config = _mock_team_config(
  310. coordination_mode="supervisor",
  311. policy_json={"failure_mode": "continue_with_warning"})
  312. call_count = 0
  313. def track_with_failure(team_run, team_config, member, member_input_json, worker_key, dry_run):
  314. nonlocal call_count
  315. call_count += 1
  316. if member.member_key == "m1":
  317. from app.application.services import TeamMemberRunResult
  318. return TeamMemberRunResult(
  319. member=member,
  320. run=AgentRunContract(
  321. id="run_fail", agent_id="a1", agent_config_id="c1",
  322. status="failed", error_code="test_error",
  323. error_message="boom",
  324. created_time=datetime.utcnow()))
  325. return make_member_result(member, f"output_{member.member_key}")
  326. with patch.object(service, "_execute_single_member", side_effect=track_with_failure):
  327. results = service._execute_members(
  328. team_run=_mock_team_run(), team_config=team_config,
  329. members=members, worker_key=None, dry_run=False)
  330. # Both members executed despite first failing
  331. assert call_count == 2
  332. assert len(results) == 2
  333. def test_read_max_rounds_and_failure_mode_helpers() -> None:
  334. service, _, _, _ = _build_service_with_mock_agent()
  335. config_default = _mock_team_config(policy_json={})
  336. assert service._read_max_rounds(config_default) == 3
  337. assert service._read_failure_mode(config_default) == "stop_on_critical"
  338. config_custom = _mock_team_config(policy_json={
  339. "max_rounds": 5, "failure_mode": "continue_with_warning"})
  340. assert service._read_max_rounds(config_custom) == 5
  341. assert service._read_failure_mode(config_custom) == "continue_with_warning"
  342. config_clamped = _mock_team_config(policy_json={"max_rounds": 50})
  343. assert service._read_max_rounds(config_clamped) == 20