provider.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract,
  13. *,
  14. provider_base_url: str | None = None,
  15. provider_api_key: str | None = None,
  16. timeout_seconds: float = 60.0,
  17. ) -> ChatCompletionResponseContract:
  18. if payload.model is None:
  19. raise ModelProviderClientError("model is required for chat completion")
  20. request_payload = {
  21. "model": payload.model,
  22. "messages": [item.model_dump(mode="json") for item in payload.messages],
  23. }
  24. if payload.temperature is not None:
  25. request_payload["temperature"] = payload.temperature
  26. if payload.max_tokens is not None:
  27. request_payload["max_tokens"] = payload.max_tokens
  28. if payload.tools_json:
  29. request_payload["tools"] = payload.tools_json
  30. if payload.tool_choice is not None:
  31. request_payload["tool_choice"] = payload.tool_choice
  32. request_headers: dict[str, str] = {"content-type": "application/json"}
  33. api_key = (
  34. provider_api_key
  35. if provider_api_key is not None
  36. else self.settings.provider_api_key
  37. )
  38. if api_key:
  39. request_headers["authorization"] = f"Bearer {api_key}"
  40. try:
  41. base_url = provider_base_url or self.settings.provider_base_url
  42. with httpx.Client(timeout=timeout_seconds) as client:
  43. response = client.post(
  44. f"{base_url.rstrip('/')}/chat/completions",
  45. json=request_payload,
  46. headers=request_headers)
  47. response.raise_for_status()
  48. except httpx.HTTPError as exc:
  49. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  50. response_json = _coerce_json_dict(response.json())
  51. content = _extract_response_content(response_json)
  52. finish_reason = _extract_finish_reason(response_json)
  53. tool_calls_json = _extract_tool_calls_json(response_json)
  54. usage_json = _extract_usage_json(response_json)
  55. return ChatCompletionResponseContract(
  56. model=payload.model,
  57. content=content,
  58. finish_reason=finish_reason,
  59. tool_calls_json=tool_calls_json,
  60. usage_json=usage_json,
  61. raw_response_json=response_json)
  62. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  63. if isinstance(payload, dict):
  64. return {str(key): value for key, value in payload.items()}
  65. return {}
  66. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  67. choices = payload.get("choices")
  68. if isinstance(choices, list) and choices:
  69. first_choice = choices[0]
  70. if isinstance(first_choice, dict):
  71. message = first_choice.get("message")
  72. if isinstance(message, dict):
  73. content = message.get("content")
  74. if isinstance(content, str):
  75. return content
  76. text = first_choice.get("text")
  77. if isinstance(text, str):
  78. return text
  79. return ""
  80. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  81. choices = payload.get("choices")
  82. if isinstance(choices, list) and choices:
  83. first_choice = choices[0]
  84. if isinstance(first_choice, dict):
  85. finish_reason = first_choice.get("finish_reason")
  86. if isinstance(finish_reason, str):
  87. return finish_reason
  88. return None
  89. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  90. choices = payload.get("choices")
  91. if isinstance(choices, list) and choices:
  92. first_choice = choices[0]
  93. if isinstance(first_choice, dict):
  94. message = first_choice.get("message")
  95. if isinstance(message, dict):
  96. tool_calls = message.get("tool_calls")
  97. if isinstance(tool_calls, list):
  98. return [
  99. {str(item_key): item_value for item_key, item_value in item.items()}
  100. for item in tool_calls
  101. if isinstance(item, dict)
  102. ]
  103. return []
  104. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  105. usage = payload.get("usage")
  106. if isinstance(usage, dict):
  107. return {str(key): value for key, value in usage.items()}
  108. return {}