test_model_service.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. from pathlib import Path
  2. from tests.conftest import (
  3. build_fastapi_test_client,
  4. build_postgres_database_url,
  5. build_postgres_engine,
  6. prepare_known_service_import,
  7. )
  8. def test_model_service_post_contract_supports_models_and_providers(
  9. tmp_path: Path,
  10. monkeypatch,
  11. ) -> None:
  12. prepare_known_service_import("model-gateway-service")
  13. from app.bootstrap.app import create_app
  14. from app.db.models import Base
  15. from core_db import create_session_factory
  16. database_url = build_postgres_database_url(tmp_path, "models")
  17. monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
  18. engine = build_postgres_engine(database_url)
  19. Base.metadata.create_all(engine)
  20. app = create_app()
  21. app.state.session_factory = create_session_factory(engine)
  22. client = build_fastapi_test_client(app)
  23. provider_response = client.post(
  24. "/models/providers/create",
  25. json={
  26. "name": "Local OpenAI Compatible",
  27. "providerType": "openai_compatible",
  28. "baseUrl": "http://127.0.0.1:11434/v1",
  29. "apiKey": "local-secret",
  30. "models": [
  31. {
  32. "modelId": "llama3.1",
  33. "displayName": "Llama 3.1",
  34. "modelType": "chat",
  35. }
  36. ],
  37. "defaultModel": "llama3.1",
  38. },
  39. )
  40. assert provider_response.status_code == 200
  41. provider_payload = provider_response.json()["data"]
  42. assert provider_payload["apiKeyRef"] == "loc***masked"
  43. assert provider_payload["models"][0]["modelId"] == "llama3.1"
  44. providers_response = client.post(
  45. "/models/providers/list",
  46. json={"page": 1, "pageSize": 20},
  47. )
  48. assert providers_response.status_code == 200
  49. assert providers_response.json()["data"]["total"] == 1
  50. discover_response = client.post(
  51. "/models/providers/discover",
  52. json={"providerId": provider_payload["id"]},
  53. )
  54. assert discover_response.status_code == 200
  55. assert discover_response.json()["data"]["models"][0]["modelId"] == "llama3.1"
  56. model_response = client.post(
  57. "/models/create",
  58. json={
  59. "name": "Local Chat",
  60. "providerId": provider_payload["id"],
  61. "providerType": "openai_compatible",
  62. "modelName": "llama3.1",
  63. "capabilities": ["chat"],
  64. "timeoutSeconds": 30,
  65. },
  66. )
  67. assert model_response.status_code == 200
  68. model_payload = model_response.json()["data"]
  69. assert model_payload["modelName"] == "llama3.1"
  70. assert model_payload["providerId"] == provider_payload["id"]
  71. assert model_payload["providerBaseUrl"] == "http://127.0.0.1:11434/v1"
  72. assert model_payload["hasProviderApiKey"] is True
  73. assert "code" not in model_payload
  74. models_response = client.post(
  75. "/models/list",
  76. json={"page": 1, "pageSize": 20, "keyword": "local"},
  77. )
  78. assert models_response.status_code == 200
  79. assert models_response.json()["data"]["total"] == 1
  80. update_response = client.post(
  81. "/models/update",
  82. json={
  83. "modelId": model_payload["id"],
  84. "name": "Local Chat Updated",
  85. "defaultTemperature": 0.2,
  86. },
  87. )
  88. assert update_response.status_code == 200
  89. assert update_response.json()["data"]["defaultTemperature"] == 0.2
  90. delete_response = client.post(
  91. "/models/delete",
  92. json={"modelId": model_payload["id"]},
  93. )
  94. assert delete_response.status_code == 200
  95. assert delete_response.json()["data"]["deleted"] is True
  96. def test_model_provider_client_supports_anthropic_messages(monkeypatch) -> None:
  97. prepare_known_service_import("model-gateway-service")
  98. import app.infrastructure.provider as provider_module
  99. from app.bootstrap.settings import ModelGatewayServiceSettings
  100. from app.infrastructure.provider import ModelProviderClient
  101. from core_domain import ChatCompletionRequestContract
  102. captured: dict[str, object] = {}
  103. class FakeResponse:
  104. text = "{}"
  105. def raise_for_status(self) -> None:
  106. return None
  107. def json(self) -> dict[str, object]:
  108. return {
  109. "model": "claude-3-5-sonnet-20241022",
  110. "content": [{"type": "text", "text": "ready"}],
  111. "stop_reason": "end_turn",
  112. "usage": {"input_tokens": 12, "output_tokens": 3},
  113. }
  114. class FakeClient:
  115. def __init__(self, *, timeout: float) -> None:
  116. captured["timeout"] = timeout
  117. def __enter__(self) -> "FakeClient":
  118. return self
  119. def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
  120. return None
  121. def post(
  122. self,
  123. url: str,
  124. *,
  125. json: dict[str, object],
  126. headers: dict[str, str]) -> FakeResponse:
  127. captured["url"] = url
  128. captured["json"] = json
  129. captured["headers"] = headers
  130. return FakeResponse()
  131. monkeypatch.setattr(provider_module.httpx, "Client", FakeClient)
  132. client = ModelProviderClient(settings=ModelGatewayServiceSettings())
  133. response = client.create_chat_completion(
  134. ChatCompletionRequestContract(
  135. model="claude-3-5-sonnet-20241022",
  136. messages=[
  137. {"role": "system", "content": "Be concise."},
  138. {"role": "user", "content": "Ping"},
  139. ],
  140. max_tokens=128),
  141. provider_type="anthropic",
  142. provider_base_url="https://api.anthropic.com",
  143. provider_api_key="sk-test",
  144. timeout_seconds=15)
  145. assert captured["url"] == "https://api.anthropic.com/v1/messages"
  146. assert captured["headers"] == {
  147. "content-type": "application/json",
  148. "x-api-key": "sk-test",
  149. "anthropic-version": "2023-06-01",
  150. }
  151. assert captured["json"] == {
  152. "model": "claude-3-5-sonnet-20241022",
  153. "max_tokens": 128,
  154. "messages": [{"role": "user", "content": "Ping"}],
  155. "system": "Be concise.",
  156. }
  157. assert response.content == "ready"
  158. assert response.finish_reason == "end_turn"
  159. def test_model_service_backfills_legacy_model_connections_as_providers(
  160. tmp_path: Path,
  161. monkeypatch,
  162. ) -> None:
  163. prepare_known_service_import("model-gateway-service")
  164. from app.bootstrap.app import create_app
  165. from app.db.models import Base
  166. from core_db import create_session_factory
  167. database_url = build_postgres_database_url(tmp_path, "models-backfill")
  168. monkeypatch.setenv("AGENT_PLATFORM_DATABASE_URL", database_url)
  169. engine = build_postgres_engine(database_url)
  170. Base.metadata.create_all(engine)
  171. app = create_app()
  172. app.state.session_factory = create_session_factory(engine)
  173. client = build_fastapi_test_client(app)
  174. model_response = client.post(
  175. "/models/create",
  176. json={
  177. "name": "Legacy Anthropic",
  178. "providerType": "anthropic",
  179. "providerBaseUrl": "https://api.anthropic.com",
  180. "providerApiKey": "sk-legacy",
  181. "modelName": "claude-3-5-sonnet-20241022",
  182. "capabilities": ["chat"],
  183. },
  184. )
  185. assert model_response.status_code == 200
  186. assert model_response.json()["data"]["providerId"] is None
  187. providers_response = client.post(
  188. "/models/providers/list",
  189. json={"page": 1, "pageSize": 20},
  190. )
  191. assert providers_response.status_code == 200
  192. providers_payload = providers_response.json()["data"]
  193. assert providers_payload["total"] == 1
  194. provider_payload = providers_payload["items"][0]
  195. assert provider_payload["providerType"] == "anthropic"
  196. assert provider_payload["baseUrl"] == "https://api.anthropic.com"
  197. assert provider_payload["models"][0]["modelId"] == "claude-3-5-sonnet-20241022"
  198. models_response = client.post(
  199. "/models/list",
  200. json={"page": 1, "pageSize": 20},
  201. )
  202. assert models_response.status_code == 200
  203. assert models_response.json()["data"]["items"][0]["providerId"] == provider_payload["id"]