test_team_service.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. from pathlib import Path
  2. from datetime import datetime
  3. from tests.conftest import (
  4. build_fastapi_test_client,
  5. build_postgres_database_url,
  6. build_postgres_engine,
  7. prepare_known_service_import,
  8. )
  9. def test_team_service_post_contract_supports_team_configs_and_runs(
  10. tmp_path: Path,
  11. monkeypatch,
  12. ) -> None:
  13. prepare_known_service_import("team-service")
  14. from app.bootstrap.app import create_app
  15. from app.db.models import Base
  16. from core_db import create_session_factory
  17. database_url = build_postgres_database_url(tmp_path, "teams-api")
  18. monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
  19. monkeypatch.setenv("AGENT_PLATFORM_REDIS_URL", "")
  20. monkeypatch.setenv("AGENT_PLATFORM_AUTO_WORKER_ENABLED", "false")
  21. engine = build_postgres_engine(database_url)
  22. Base.metadata.create_all(engine)
  23. app = create_app()
  24. app.state.session_factory = create_session_factory(engine)
  25. client = build_fastapi_test_client(app)
  26. team_response = client.post(
  27. "/teams/create",
  28. json={
  29. "name": "Support Team",
  30. "description": "Handles support escalations",
  31. "teamType": "collaborative",
  32. "ownerUserId": "demo-user",
  33. },
  34. )
  35. assert team_response.status_code == 200
  36. team_payload = team_response.json()["data"]
  37. assert team_payload["name"] == "Support Team"
  38. assert team_payload["teamType"] == "collaborative"
  39. assert "code" not in team_payload
  40. config_response = client.post(
  41. "/teams/configs/create",
  42. json={
  43. "teamId": team_payload["id"],
  44. "coordinationMode": "supervisor",
  45. "objective": "Resolve the customer request",
  46. "memberRefs": [
  47. {
  48. "role": "worker",
  49. "agentId": "agent_support",
  50. "responsibility": "Draft the answer",
  51. }
  52. ],
  53. "policy": {
  54. "max_rounds": 3,
  55. "handoff": "supervisor",
  56. },
  57. },
  58. )
  59. assert config_response.status_code == 200
  60. config_payload = config_response.json()["data"]
  61. assert config_payload["teamId"] == team_payload["id"]
  62. assert config_payload["memberRefs"][0]["role"] == "specialist"
  63. assert config_payload["memberRefs"][0]["member_key"] == "member_1"
  64. list_response = client.post(
  65. "/teams/list",
  66. json={"page": 1, "pageSize": 20, "keyword": "support"},
  67. )
  68. assert list_response.status_code == 200
  69. assert list_response.json()["data"]["total"] == 1
  70. configs_response = client.post(
  71. "/teams/configs/list",
  72. json={"page": 1, "pageSize": 20, "teamId": team_payload["id"]},
  73. )
  74. assert configs_response.status_code == 200
  75. assert configs_response.json()["data"]["total"] == 1
  76. run_response = client.post(
  77. "/teams/runs/create",
  78. json={
  79. "teamId": team_payload["id"],
  80. "teamConfigId": config_payload["id"],
  81. "inputText": "Help the customer reset MFA",
  82. },
  83. )
  84. assert run_response.status_code == 200
  85. run_payload = run_response.json()["data"]
  86. assert run_payload["teamId"] == team_payload["id"]
  87. assert run_payload["teamConfigId"] == config_payload["id"]
  88. assert run_payload["status"] == "queued"
  89. status_response = client.post(
  90. "/teams/runs/status",
  91. json={
  92. "teamRunId": run_payload["id"],
  93. "status": "completed",
  94. "workerKey": "test-worker",
  95. "outputText": "MFA reset steps prepared.",
  96. },
  97. )
  98. assert status_response.status_code == 200
  99. assert status_response.json()["data"]["status"] == "completed"
  100. runs_response = client.post(
  101. "/teams/runs/list",
  102. json={"page": 1, "pageSize": 20, "teamId": team_payload["id"]},
  103. )
  104. assert runs_response.status_code == 200
  105. assert runs_response.json()["data"]["total"] == 1
  106. update_response = client.post(
  107. "/teams/update",
  108. json={
  109. "teamId": team_payload["id"],
  110. "name": "Support Team Updated",
  111. "status": "active",
  112. },
  113. )
  114. assert update_response.status_code == 200
  115. assert update_response.json()["data"]["name"] == "Support Team Updated"
  116. assert update_response.json()["data"]["status"] == "active"
  117. delete_run_response = client.post(
  118. "/teams/runs/delete",
  119. json={"teamRunId": run_payload["id"]},
  120. )
  121. assert delete_run_response.status_code == 200
  122. assert delete_run_response.json()["data"]["deleted"] is True
  123. delete_config_response = client.post(
  124. "/teams/configs/delete",
  125. json={"configId": config_payload["id"]},
  126. )
  127. assert delete_config_response.status_code == 200
  128. assert delete_config_response.json()["data"]["deleted"] is True
  129. delete_team_response = client.post(
  130. "/teams/delete",
  131. json={"teamId": team_payload["id"]},
  132. )
  133. assert delete_team_response.status_code == 200
  134. assert delete_team_response.json()["data"]["deleted"] is True
  135. def test_team_service_compacts_member_context_between_agent_calls() -> None:
  136. prepare_known_service_import("team-service")
  137. from app.application.services import TeamApplicationService, TeamMemberRunResult
  138. from core_domain import AgentRunContract, TeamMemberContract
  139. service = TeamApplicationService(
  140. team_repository=None,
  141. team_config_repository=None,
  142. team_run_repository=None)
  143. result = TeamMemberRunResult(
  144. member=TeamMemberContract(
  145. member_key="member_1",
  146. agent_id="agent_1",
  147. role="supervisor",
  148. name="Planner"),
  149. run=AgentRunContract(
  150. id="run_1",
  151. agent_id="agent_1",
  152. agent_config_id="config_1",
  153. output_text="x" * 2000,
  154. output_json={
  155. "model": "demo-model",
  156. "finish_reason": "stop",
  157. "messages": [{"role": "user", "content": "large prompt"}],
  158. "raw_response_json": {"thinking": "hidden debug payload"},
  159. },
  160. status="completed",
  161. created_time=datetime.utcnow()))
  162. prior_output = service._compact_prior_output(result)
  163. member_json = service._member_result_to_json(result)
  164. assert prior_output["output_text"].endswith("[truncated]")
  165. assert prior_output["output_json"] == {
  166. "model": "demo-model",
  167. "finish_reason": "stop",
  168. "debug_payload_omitted": True,
  169. }
  170. assert "messages" not in member_json["output_json"]
  171. assert "raw_response_json" not in member_json["output_json"]
  172. def _build_service_with_mock_agent() -> tuple:
  173. prepare_known_service_import("team-service")
  174. from unittest.mock import MagicMock
  175. from app.application.services import TeamApplicationService, TeamMemberRunResult
  176. from core_domain import AgentRunContract, TeamMemberContract
  177. call_log: list[str] = []
  178. def make_member_result(member: TeamMemberContract, text: str) -> TeamMemberRunResult:
  179. return TeamMemberRunResult(
  180. member=member,
  181. run=AgentRunContract(
  182. id=f"run_{member.member_key}",
  183. agent_id=member.agent_id,
  184. agent_config_id=member.agent_config_id,
  185. output_text=text,
  186. output_json={},
  187. status="completed",
  188. created_time=datetime.utcnow()))
  189. mock_client = MagicMock()
  190. mock_client.create_agent_run = MagicMock(
  191. side_effect=lambda **kw: AgentRunContract(
  192. id="run_mock", agent_id=kw.get("agent_id", "a"),
  193. status="created", created_time=datetime.utcnow()))
  194. mock_client.execute_agent_run = MagicMock(
  195. side_effect=lambda **kw: AgentRunContract(
  196. id=kw.get("agent_run_id", "run_mock"),
  197. agent_id="a", output_text="ok",
  198. output_json={}, status="completed",
  199. created_time=datetime.utcnow()))
  200. def track_execute(team_run, team_config, member, member_input_json, worker_key, dry_run):
  201. prior = member_input_json.get("prior_member_outputs", [])
  202. call_log.append(f"{member.member_key}:{member.role}:prior={len(prior)}")
  203. return make_member_result(member, f"output_{member.member_key}")
  204. service = TeamApplicationService(
  205. team_repository=None,
  206. team_config_repository=None,
  207. team_run_repository=None,
  208. agent_client=mock_client)
  209. return service, call_log, track_execute
  210. def test_supervisor_mode_executes_lead_first_then_others() -> None:
  211. service, call_log, track_execute = _build_service_with_mock_agent()
  212. from unittest.mock import patch
  213. from core_domain import TeamMemberContract
  214. members = [
  215. TeamMemberContract(member_key="worker_1", agent_id="a1", role="executor"),
  216. TeamMemberContract(member_key="lead_1", agent_id="a2", role="supervisor"),
  217. TeamMemberContract(member_key="worker_2", agent_id="a3", role="reviewer"),
  218. ]
  219. team_config = type("C", (), {
  220. "coordination_mode": "supervisor",
  221. "objective": "test",
  222. "policy_json": {"supervisor_synthesis": True},
  223. })()
  224. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  225. results = service._execute_members(
  226. team_run=MagicMock(), team_config=team_config,
  227. members=members, worker_key=None, dry_run=False)
  228. # lead runs first, then workers, then synthesis = 4 executions
  229. assert len(results) == 4
  230. keys = [r.member.member_key for r in results]
  231. assert keys[0] == "lead_1" # supervisor first
  232. assert "worker_1" in keys[1:3]
  233. assert "worker_2" in keys[1:3]
  234. assert keys[3] == "lead_1" # synthesis pass
  235. def test_pipeline_mode_chains_single_prior_output() -> None:
  236. service, call_log, track_execute = _build_service_with_mock_agent()
  237. from unittest.mock import patch
  238. from core_domain import TeamMemberContract
  239. members = [
  240. TeamMemberContract(member_key="m1", agent_id="a1", role="planner"),
  241. TeamMemberContract(member_key="m2", agent_id="a2", role="executor"),
  242. TeamMemberContract(member_key="m3", agent_id="a3", role="reviewer"),
  243. ]
  244. team_config = type("C", (), {
  245. "coordination_mode": "pipeline",
  246. "objective": "test",
  247. "policy_json": {},
  248. })()
  249. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  250. results = service._execute_members(
  251. team_run=MagicMock(), team_config=team_config,
  252. members=members, worker_key=None, dry_run=False)
  253. assert len(results) == 3
  254. # m1: no prior, m2: 1 prior, m3: 1 prior (only previous, not all)
  255. assert call_log[0] == "m1:planner:prior=0"
  256. assert call_log[1] == "m2:executor:prior=1"
  257. assert call_log[2] == "m3:reviewer:prior=1"
  258. def test_debate_mode_executes_multiple_rounds() -> None:
  259. service, call_log, track_execute = _build_service_with_mock_agent()
  260. from unittest.mock import patch
  261. from core_domain import TeamMemberContract
  262. members = [
  263. TeamMemberContract(member_key="m1", agent_id="a1", role="executor"),
  264. TeamMemberContract(member_key="m2", agent_id="a2", role="reviewer"),
  265. ]
  266. team_config = type("C", (), {
  267. "coordination_mode": "debate",
  268. "objective": "test",
  269. "policy_json": {"max_rounds": 3},
  270. })()
  271. with patch.object(service, "_execute_single_member", side_effect=track_execute):
  272. results = service._execute_members(
  273. team_run=MagicMock(), team_config=team_config,
  274. members=members, worker_key=None, dry_run=False)
  275. # 2 members x 3 rounds = 6 executions, final_results = last round
  276. assert len(results) == 2
  277. assert len(call_log) == 6
  278. # Round 1: prior=0 for first, prior=1 for second
  279. assert call_log[0] == "m1:executor:prior=0"
  280. assert call_log[1] == "m2:reviewer:prior=1"
  281. # Round 2: prior=2 (history from round 1)
  282. assert call_log[2] == "m1:executor:prior=2"
  283. assert call_log[3] == "m2:reviewer:prior=3"
  284. # Round 3: prior=4
  285. assert call_log[4] == "m1:executor:prior=4"
  286. def test_failure_mode_continue_allows_partial_failure() -> None:
  287. service, call_log, track_execute = _build_service_with_mock_agent()
  288. from unittest.mock import patch, MagicMock
  289. from core_domain import TeamMemberContract, AgentRunContract
  290. members = [
  291. TeamMemberContract(member_key="m1", agent_id="a1", role="executor"),
  292. TeamMemberContract(member_key="m2", agent_id="a2", role="executor"),
  293. ]
  294. team_config = type("C", (), {
  295. "coordination_mode": "supervisor",
  296. "objective": "test",
  297. "policy_json": {"failure_mode": "continue_with_warning"},
  298. })()
  299. call_count = 0
  300. def track_with_failure(team_run, team_config, member, member_input_json, worker_key, dry_run):
  301. nonlocal call_count
  302. call_count += 1
  303. if member.member_key == "m1":
  304. return TeamMemberRunResult(
  305. member=member,
  306. run=AgentRunContract(
  307. id="run_fail", agent_id="a1",
  308. status="failed", error_code="test_error",
  309. error_message="boom",
  310. created_time=datetime.utcnow()))
  311. return make_member_result(member, f"output_{member.member_key}")
  312. def make_member_result(member, text):
  313. return TeamMemberRunResult(
  314. member=member,
  315. run=AgentRunContract(
  316. id=f"run_{member.member_key}", agent_id=member.agent_id,
  317. output_text=text, output_json={},
  318. status="completed", created_time=datetime.utcnow()))
  319. with patch.object(service, "_execute_single_member", side_effect=track_with_failure):
  320. results = service._execute_members(
  321. team_run=MagicMock(), team_config=team_config,
  322. members=members, worker_key=None, dry_run=False)
  323. # Both members executed despite first failing
  324. assert call_count == 2
  325. assert len(results) == 2
  326. def test_read_max_rounds_and_failure_mode_helpers() -> None:
  327. service, _, _ = _build_service_with_mock_agent()
  328. config_default = type("C", (), {"policy_json": {}})()
  329. assert service._read_max_rounds(config_default) == 3
  330. assert service._read_failure_mode(config_default) == "stop_on_critical"
  331. config_custom = type("C", (), {"policy_json": {
  332. "max_rounds": 5, "failure_mode": "continue_with_warning"}})()
  333. assert service._read_max_rounds(config_custom) == 5
  334. assert service._read_failure_mode(config_custom) == "continue_with_warning"
  335. config_clamped = type("C", (), {"policy_json": {"max_rounds": 50}})()
  336. assert service._read_max_rounds(config_clamped) == 20