provider.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract,
  13. ) -> ChatCompletionResponseContract:
  14. if payload.model is None:
  15. raise ModelProviderClientError("model is required for chat completion")
  16. request_payload = {
  17. "model": payload.model,
  18. "messages": [item.model_dump(mode="json") for item in payload.messages],
  19. }
  20. if payload.temperature is not None:
  21. request_payload["temperature"] = payload.temperature
  22. if payload.max_tokens is not None:
  23. request_payload["max_tokens"] = payload.max_tokens
  24. request_headers: dict[str, str] = {"content-type": "application/json"}
  25. if self.settings.provider_api_key:
  26. request_headers["authorization"] = f"Bearer {self.settings.provider_api_key}"
  27. try:
  28. with httpx.Client(timeout=60.0) as client:
  29. response = client.post(
  30. f"{self.settings.provider_base_url.rstrip('/')}/chat/completions",
  31. json=request_payload,
  32. headers=request_headers,
  33. )
  34. response.raise_for_status()
  35. except httpx.HTTPError as exc:
  36. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  37. response_json = _coerce_json_dict(response.json())
  38. content = _extract_response_content(response_json)
  39. finish_reason = _extract_finish_reason(response_json)
  40. usage_json = _extract_usage_json(response_json)
  41. return ChatCompletionResponseContract(
  42. model=payload.model,
  43. content=content,
  44. finish_reason=finish_reason,
  45. usage_json=usage_json,
  46. raw_response_json=response_json,
  47. )
  48. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  49. if isinstance(payload, dict):
  50. return {str(key): value for key, value in payload.items()}
  51. return {}
  52. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  53. choices = payload.get("choices")
  54. if isinstance(choices, list) and choices:
  55. first_choice = choices[0]
  56. if isinstance(first_choice, dict):
  57. message = first_choice.get("message")
  58. if isinstance(message, dict):
  59. content = message.get("content")
  60. if isinstance(content, str):
  61. return content
  62. text = first_choice.get("text")
  63. if isinstance(text, str):
  64. return text
  65. return ""
  66. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  67. choices = payload.get("choices")
  68. if isinstance(choices, list) and choices:
  69. first_choice = choices[0]
  70. if isinstance(first_choice, dict):
  71. finish_reason = first_choice.get("finish_reason")
  72. if isinstance(finish_reason, str):
  73. return finish_reason
  74. return None
  75. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  76. usage = payload.get("usage")
  77. if isinstance(usage, dict):
  78. return {str(key): value for key, value in usage.items()}
  79. return {}