test_model_service.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from pathlib import Path
  2. from tests.conftest import (
  3. build_fastapi_test_client,
  4. build_postgres_database_url,
  5. build_postgres_engine,
  6. prepare_known_service_import,
  7. )
  def test_model_service_post_contract_supports_models_and_providers(
      tmp_path: Path,
      monkeypatch,
  ) -> None:
      """Exercise the model-gateway provider and model POST endpoints end to end.

      Creates an OpenAI-compatible provider that declares one model, checks that
      the model appears in ``/models/list`` and in provider discovery, then
      creates, searches, updates, and deletes a model bound to that provider.
      """
      prepare_known_service_import("model-gateway-service")
      from app.bootstrap.app import create_app
      from app.db.models import Base
      from core_db import create_session_factory

      database_url = build_postgres_database_url(tmp_path, "models")
      monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
      engine = build_postgres_engine(database_url)
      try:
          Base.metadata.create_all(engine)
          app = create_app()
          # Override the app's session factory so requests hit this test engine.
          app.state.session_factory = create_session_factory(engine)
          client = build_fastapi_test_client(app)

          # Create a provider declaring one model; the API key must come back masked.
          provider_response = client.post(
              "/models/providers/create",
              json={
                  "name": "Local OpenAI Compatible",
                  "providerType": "openai_compatible",
                  "baseUrl": "http://127.0.0.1:11434/v1",
                  "apiKey": "local-secret",
                  "models": [
                      {
                          "modelId": "llama3.1",
                          "displayName": "Llama 3.1",
                          "modelType": "chat",
                      }
                  ],
                  "defaultModel": "llama3.1",
              },
          )
          assert provider_response.status_code == 200
          provider_payload = provider_response.json()["data"]
          assert provider_payload["apiKeyRef"] == "loc***masked"
          assert provider_payload["models"][0]["modelId"] == "llama3.1"

          # The declared model is expected to be listed right after provider creation.
          auto_synced_response = client.post(
              "/models/list",
              json={"page": 1, "pageSize": 20},
          )
          assert auto_synced_response.status_code == 200
          auto_synced_payload = auto_synced_response.json()["data"]
          assert auto_synced_payload["total"] == 1
          assert auto_synced_payload["items"][0]["modelName"] == "llama3.1"

          providers_response = client.post(
              "/models/providers/list",
              json={"page": 1, "pageSize": 20},
          )
          assert providers_response.status_code == 200
          assert providers_response.json()["data"]["total"] == 1

          # Discovery must report the same model and must not duplicate it in the list.
          discover_response = client.post(
              "/models/providers/discover",
              json={"providerId": provider_payload["id"]},
          )
          assert discover_response.status_code == 200
          assert discover_response.json()["data"]["models"][0]["modelId"] == "llama3.1"

          synced_models_response = client.post(
              "/models/list",
              json={"page": 1, "pageSize": 20},
          )
          assert synced_models_response.status_code == 200
          synced_models_payload = synced_models_response.json()["data"]
          assert synced_models_payload["total"] == 1
          assert synced_models_payload["items"][0]["modelName"] == "llama3.1"
          assert synced_models_payload["items"][0]["providerId"] == provider_payload["id"]

          # Creating a model for the provider's model id should reuse the synced row.
          model_response = client.post(
              "/models/create",
              json={
                  "name": "Local Chat",
                  "providerId": provider_payload["id"],
                  "providerType": "openai_compatible",
                  "modelName": "llama3.1",
                  "capabilities": ["chat"],
                  "timeoutSeconds": 30,
              },
          )
          assert model_response.status_code == 200
          model_payload = model_response.json()["data"]
          assert model_payload["modelName"] == "llama3.1"
          assert model_payload["providerId"] == provider_payload["id"]
          assert model_payload["id"] == synced_models_payload["items"][0]["id"]
          assert model_payload["providerBaseUrl"] == "http://127.0.0.1:11434/v1"
          assert model_payload["hasProviderApiKey"] is True
          assert "code" not in model_payload

          models_response = client.post(
              "/models/list",
              json={"page": 1, "pageSize": 20, "keyword": "local"},
          )
          assert models_response.status_code == 200
          assert models_response.json()["data"]["total"] == 1

          update_response = client.post(
              "/models/update",
              json={
                  "modelId": model_payload["id"],
                  "name": "Local Chat Updated",
                  "defaultTemperature": 0.2,
              },
          )
          assert update_response.status_code == 200
          assert update_response.json()["data"]["defaultTemperature"] == 0.2

          delete_response = client.post(
              "/models/delete",
              json={"modelId": model_payload["id"]},
          )
          assert delete_response.status_code == 200
          assert delete_response.json()["data"]["deleted"] is True
      finally:
          # Release the engine's pooled connections so they do not outlive the test.
          # Assumes build_postgres_engine returns a SQLAlchemy Engine (it is passed
          # to Base.metadata.create_all above).
          engine.dispose()
  def test_model_provider_client_supports_anthropic_messages(monkeypatch) -> None:
      """Check that ModelProviderClient issues Anthropic Messages API requests.

      ``httpx.Client`` is replaced with a fake that records the outgoing request.
      The test then asserts on the client timeout, URL, headers, and JSON body the
      provider client produced, and on how the canned reply is mapped back.
      """
      prepare_known_service_import("model-gateway-service")
      import app.infrastructure.provider as provider_module
      from app.bootstrap.settings import ModelGatewayServiceSettings
      from app.infrastructure.provider import ModelProviderClient
      from core_domain import ChatCompletionRequestContract

      # Filled in by FakeClient; inspected by the assertions at the end.
      captured: dict[str, object] = {}

      class FakeResponse:
          """Canned successful Anthropic Messages API reply."""

          text = "{}"

          def raise_for_status(self) -> None:
              return None

          def json(self) -> dict[str, object]:
              return {
                  "model": "claude-3-5-sonnet-20241022",
                  "content": [{"type": "text", "text": "ready"}],
                  "stop_reason": "end_turn",
                  "usage": {"input_tokens": 12, "output_tokens": 3},
              }

      class FakeClient:
          """Stand-in for ``httpx.Client`` that records the request it receives."""

          def __init__(self, *, timeout: float) -> None:
              captured["timeout"] = timeout

          def __enter__(self) -> "FakeClient":
              return self

          def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
              return None

          def post(
              self,
              url: str,
              *,
              json: dict[str, object],
              headers: dict[str, str],
          ) -> FakeResponse:
              captured["url"] = url
              captured["json"] = json
              captured["headers"] = headers
              return FakeResponse()

      monkeypatch.setattr(provider_module.httpx, "Client", FakeClient)
      client = ModelProviderClient(settings=ModelGatewayServiceSettings())
      response = client.create_chat_completion(
          ChatCompletionRequestContract(
              model="claude-3-5-sonnet-20241022",
              messages=[
                  {"role": "system", "content": "Be concise."},
                  {"role": "user", "content": "Ping"},
              ],
              max_tokens=128,
          ),
          provider_type="anthropic",
          provider_base_url="https://api.anthropic.com",
          provider_api_key="sk-test",
          timeout_seconds=15,
      )

      # The per-call timeout is expected to be what the HTTP client is built with.
      assert captured["timeout"] == 15
      assert captured["url"] == "https://api.anthropic.com/v1/messages"
      assert captured["headers"] == {
          "content-type": "application/json",
          "x-api-key": "sk-test",
          "anthropic-version": "2023-06-01",
      }
      # The system message is expected in the top-level "system" field, not in "messages".
      assert captured["json"] == {
          "model": "claude-3-5-sonnet-20241022",
          "max_tokens": 128,
          "messages": [{"role": "user", "content": "Ping"}],
          "system": "Be concise.",
      }
      assert response.content == "ready"
      assert response.finish_reason == "end_turn"
  def test_model_service_backfills_legacy_model_connections_as_providers(
      tmp_path: Path,
      monkeypatch,
  ) -> None:
      """Check that a model created with inline provider settings is backfilled.

      The model is created with a provider type, base URL, and API key but no
      ``providerId``. A matching provider is expected to appear in
      ``/models/providers/list``, and the model list is expected to reference it.
      """
      prepare_known_service_import("model-gateway-service")
      from app.bootstrap.app import create_app
      from app.db.models import Base
      from core_db import create_session_factory

      database_url = build_postgres_database_url(tmp_path, "models-backfill")
      monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
      engine = build_postgres_engine(database_url)
      try:
          Base.metadata.create_all(engine)
          app = create_app()
          # Override the app's session factory so requests hit this test engine.
          app.state.session_factory = create_session_factory(engine)
          client = build_fastapi_test_client(app)

          # Legacy-style create: provider connection details inline, no providerId.
          model_response = client.post(
              "/models/create",
              json={
                  "name": "Legacy Anthropic",
                  "providerType": "anthropic",
                  "providerBaseUrl": "https://api.anthropic.com",
                  "providerApiKey": "sk-legacy",
                  "modelName": "claude-3-5-sonnet-20241022",
                  "capabilities": ["chat"],
              },
          )
          assert model_response.status_code == 200
          assert model_response.json()["data"]["providerId"] is None

          # A provider mirroring the inline connection is expected to exist now.
          providers_response = client.post(
              "/models/providers/list",
              json={"page": 1, "pageSize": 20},
          )
          assert providers_response.status_code == 200
          providers_payload = providers_response.json()["data"]
          assert providers_payload["total"] == 1
          provider_payload = providers_payload["items"][0]
          assert provider_payload["providerType"] == "anthropic"
          assert provider_payload["baseUrl"] == "https://api.anthropic.com"
          assert provider_payload["models"][0]["modelId"] == "claude-3-5-sonnet-20241022"

          # The listed model is expected to point at the backfilled provider.
          models_response = client.post(
              "/models/list",
              json={"page": 1, "pageSize": 20},
          )
          assert models_response.status_code == 200
          models_payload = models_response.json()["data"]
          assert models_payload["items"][0]["providerId"] == provider_payload["id"]
      finally:
          # Release the engine's pooled connections so they do not outlive the test.
          # Assumes build_postgres_engine returns a SQLAlchemy Engine (it is passed
          # to Base.metadata.create_all above).
          engine.dispose()