| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import httpx
- from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
- from core_shared import JSONValue
- from app.bootstrap.settings import ModelGatewayServiceSettings
class ModelProviderClientError(Exception):
    """Raised by ``ModelProviderClient`` when a chat completion cannot be built or obtained."""
class ModelProviderClient:
    """Blocking HTTP client for a model provider's ``/chat/completions`` endpoint.

    The base URL and API key default to the service settings and can be
    overridden per call. Every failure is surfaced as
    ``ModelProviderClientError``.
    """

    def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
        self.settings = settings

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
        *,
        provider_base_url: str | None = None,
        provider_api_key: str | None = None,
        timeout_seconds: float = 60.0,
    ) -> ChatCompletionResponseContract:
        """Send ``payload`` to the provider and return the parsed completion.

        Args:
            payload: The chat request; ``payload.model`` must not be ``None``.
            provider_base_url: Used instead of ``settings.provider_base_url``
                when truthy. Any trailing ``/`` is stripped before
                ``/chat/completions`` is appended.
            provider_api_key: Used instead of ``settings.provider_api_key``
                when not ``None``; passing ``""`` sends no ``authorization``
                header.
            timeout_seconds: Timeout handed to ``httpx.Client``.

        Returns:
            A response contract built from the first choice of the provider's
            JSON reply; the whole reply is kept in ``raw_response_json``.

        Raises:
            ModelProviderClientError: if no model is given, no base URL is
                available, the request fails or returns an error status, or
                the response body is not valid JSON.
        """
        if payload.model is None:
            raise ModelProviderClientError("model is required for chat completion")
        base_url = provider_base_url or self.settings.provider_base_url
        if not base_url:
            # Without this guard a missing URL surfaces as an AttributeError
            # from ``.rstrip`` rather than the client's own error type.
            raise ModelProviderClientError("model provider base url is not configured")
        request_payload = self._build_request_payload(payload)
        request_headers = self._build_request_headers(provider_api_key)
        try:
            with httpx.Client(timeout=timeout_seconds) as client:
                response = client.post(
                    f"{base_url.rstrip('/')}/chat/completions",
                    json=request_payload,
                    headers=request_headers,
                )
                response.raise_for_status()
        # httpx.InvalidURL does not derive from httpx.HTTPError, so it is
        # listed explicitly to keep malformed URLs inside the error contract.
        except (httpx.HTTPError, httpx.InvalidURL) as exc:
            raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
        try:
            decoded_body = response.json()
        except ValueError as exc:
            # json.JSONDecodeError / UnicodeDecodeError are ValueError subclasses.
            raise ModelProviderClientError(
                f"model provider returned invalid JSON: {exc}"
            ) from exc
        response_json = _coerce_json_dict(decoded_body)
        return ChatCompletionResponseContract(
            model=payload.model,
            content=_extract_response_content(response_json),
            finish_reason=_extract_finish_reason(response_json),
            tool_calls_json=_extract_tool_calls_json(response_json),
            usage_json=_extract_usage_json(response_json),
            raw_response_json=response_json,
        )

    @staticmethod
    def _build_request_payload(payload: ChatCompletionRequestContract) -> dict[str, object]:
        """Translate the request contract into the JSON body, omitting unset optional fields."""
        request_payload: dict[str, object] = {
            "model": payload.model,
            "messages": [item.model_dump(mode="json") for item in payload.messages],
        }
        if payload.temperature is not None:
            request_payload["temperature"] = payload.temperature
        if payload.max_tokens is not None:
            request_payload["max_tokens"] = payload.max_tokens
        # Empty/falsy tool lists are omitted entirely rather than sent as [].
        if payload.tools_json:
            request_payload["tools"] = payload.tools_json
        if payload.tool_choice is not None:
            request_payload["tool_choice"] = payload.tool_choice
        return request_payload

    def _build_request_headers(self, provider_api_key: str | None) -> dict[str, str]:
        """Return request headers, adding a bearer token when an API key is available.

        An explicit ``provider_api_key`` (even ``""``) takes precedence over the
        configured key; an empty key results in no ``authorization`` header.
        """
        api_key = (
            provider_api_key
            if provider_api_key is not None
            else self.settings.provider_api_key
        )
        request_headers: dict[str, str] = {"content-type": "application/json"}
        if api_key:
            request_headers["authorization"] = f"Bearer {api_key}"
        return request_headers
- def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
- if isinstance(payload, dict):
- return {str(key): value for key, value in payload.items()}
- return {}
- def _extract_response_content(payload: dict[str, JSONValue]) -> str:
- choices = payload.get("choices")
- if isinstance(choices, list) and choices:
- first_choice = choices[0]
- if isinstance(first_choice, dict):
- message = first_choice.get("message")
- if isinstance(message, dict):
- content = message.get("content")
- if isinstance(content, str):
- return content
- text = first_choice.get("text")
- if isinstance(text, str):
- return text
- return ""
- def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
- choices = payload.get("choices")
- if isinstance(choices, list) and choices:
- first_choice = choices[0]
- if isinstance(first_choice, dict):
- finish_reason = first_choice.get("finish_reason")
- if isinstance(finish_reason, str):
- return finish_reason
- return None
- def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
- choices = payload.get("choices")
- if isinstance(choices, list) and choices:
- first_choice = choices[0]
- if isinstance(first_choice, dict):
- message = first_choice.get("message")
- if isinstance(message, dict):
- tool_calls = message.get("tool_calls")
- if isinstance(tool_calls, list):
- return [
- {str(item_key): item_value for item_key, item_value in item.items()}
- for item in tool_calls
- if isinstance(item, dict)
- ]
- return []
- def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
- usage = payload.get("usage")
- if isinstance(usage, dict):
- return {str(key): value for key, value in usage.items()}
- return {}
|