import httpx

from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
from core_shared import JSONValue

from app.bootstrap.settings import ModelGatewayServiceSettings


class ModelProviderClientError(Exception):
    """Raised when a provider request fails or its response cannot be used."""


class ModelProviderClient:
    """Synchronous client that sends chat completion requests to a model provider.

    Two wire formats are supported: the Anthropic messages endpoint (selected when
    the resolved provider type is ``"anthropic"``) and an OpenAI-compatible
    ``chat/completions`` endpoint (used for every other provider type).
    """

    def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
        # Settings supply the fallback provider type, base URL and API key.
        self.settings = settings

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
        *,
        provider_type: str | None = None,
        provider_base_url: str | None = None,
        provider_api_key: str | None = None,
        timeout_seconds: float = 60.0,
    ) -> ChatCompletionResponseContract:
        """Send ``payload`` to the resolved provider and return its completion.

        Args:
            payload: The chat completion request; ``payload.model`` must be set.
            provider_type: Overrides ``settings.provider_type`` when truthy.
            provider_base_url: Overrides ``settings.provider_base_url`` when truthy.
            provider_api_key: Overrides ``settings.provider_api_key`` when not ``None``.
            timeout_seconds: Timeout passed to ``httpx.Client``.

        Raises:
            ModelProviderClientError: If the model is missing, the HTTP request
                fails, or the response body is not valid JSON.
        """
        if payload.model is None:
            raise ModelProviderClientError("model is required for chat completion")

        resolved_provider_type = provider_type or self.settings.provider_type
        if resolved_provider_type == "anthropic":
            return self._create_anthropic_message(
                payload,
                provider_base_url=provider_base_url,
                provider_api_key=provider_api_key,
                timeout_seconds=timeout_seconds,
            )
        return self._create_openai_compatible_chat_completion(
            payload,
            provider_base_url=provider_base_url,
            provider_api_key=provider_api_key,
            timeout_seconds=timeout_seconds,
        )

    def _create_openai_compatible_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
        *,
        provider_base_url: str | None,
        provider_api_key: str | None,
        timeout_seconds: float,
    ) -> ChatCompletionResponseContract:
        """POST ``payload`` to ``<base_url>/chat/completions`` and map the response.

        Optional request fields (temperature, max_tokens, tools, tool_choice) are
        only sent when set on ``payload``. The ``authorization`` header is only
        sent when the resolved API key is non-empty.

        Raises:
            ModelProviderClientError: On HTTP failure or a non-JSON response body.
        """
        request_payload: dict[str, JSONValue] = {
            "model": payload.model or "",
            "messages": [item.model_dump(mode="json") for item in payload.messages],
        }
        if payload.temperature is not None:
            request_payload["temperature"] = payload.temperature
        if payload.max_tokens is not None:
            request_payload["max_tokens"] = payload.max_tokens
        if payload.tools_json:
            request_payload["tools"] = payload.tools_json
        if payload.tool_choice is not None:
            request_payload["tool_choice"] = payload.tool_choice

        request_headers: dict[str, str] = {"content-type": "application/json"}
        # An explicit empty-string key overrides the configured key, so no
        # authorization header is sent in that case.
        api_key = (
            provider_api_key
            if provider_api_key is not None
            else self.settings.provider_api_key
        )
        if api_key:
            request_headers["authorization"] = f"Bearer {api_key}"

        base_url = provider_base_url or self.settings.provider_base_url
        response_json = _post_json(
            _join_url(base_url, "chat/completions"),
            request_payload=request_payload,
            request_headers=request_headers,
            timeout_seconds=timeout_seconds,
            error_prefix="model provider request failed",
        )
        return ChatCompletionResponseContract(
            model=payload.model,
            content=_extract_response_content(response_json),
            finish_reason=_extract_finish_reason(response_json),
            tool_calls_json=_extract_tool_calls_json(response_json),
            usage_json=_extract_usage_json(response_json),
            raw_response_json=response_json,
        )

    def _create_anthropic_message(
        self,
        payload: ChatCompletionRequestContract,
        *,
        provider_base_url: str | None,
        provider_api_key: str | None,
        timeout_seconds: float,
    ) -> ChatCompletionResponseContract:
        """POST ``payload`` to ``<base_url>/v1/messages`` and map the response.

        System messages are sent in the top-level ``system`` field. ``max_tokens``
        falls back to 1024 when ``payload.max_tokens`` is unset or zero. Tool
        definitions are not forwarded and ``tool_calls_json`` is always empty.

        Raises:
            ModelProviderClientError: If no API key is available, on HTTP failure,
                or on a non-JSON response body.
        """
        api_key = (
            provider_api_key
            if provider_api_key is not None
            else self.settings.provider_api_key
        )
        if not api_key:
            raise ModelProviderClientError("anthropic api key is required")

        system_prompt, messages = _to_anthropic_messages(payload)
        request_payload: dict[str, JSONValue] = {
            "model": payload.model or "",
            "max_tokens": payload.max_tokens or 1024,
            "messages": messages,
        }
        if system_prompt:
            request_payload["system"] = system_prompt
        if payload.temperature is not None:
            request_payload["temperature"] = payload.temperature

        request_headers = {
            "content-type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
        }

        base_url = provider_base_url or self.settings.provider_base_url
        response_json = _post_json(
            _join_url(base_url, "v1/messages"),
            request_payload=request_payload,
            request_headers=request_headers,
            timeout_seconds=timeout_seconds,
            error_prefix="anthropic request failed",
        )
        return ChatCompletionResponseContract(
            # Prefer the model name echoed by the provider, if any.
            model=_read_string(response_json, "model") or payload.model,
            content=_extract_anthropic_content(response_json),
            finish_reason=_read_string(response_json, "stop_reason"),
            tool_calls_json=[],
            usage_json=_extract_usage_json(response_json),
            raw_response_json=response_json,
        )


def _post_json(
    url: str,
    *,
    request_payload: dict[str, JSONValue],
    request_headers: dict[str, str],
    timeout_seconds: float,
    error_prefix: str,
) -> dict[str, JSONValue]:
    """POST a JSON body to ``url`` and return the decoded response as a dict.

    Every failure is reported as ``ModelProviderClientError`` whose message starts
    with ``error_prefix``: non-2xx statuses (with the status code and up to 1000
    characters of the body), transport errors, and bodies that are not valid JSON.
    A JSON body that is not an object yields an empty dict.
    """
    try:
        with httpx.Client(timeout=timeout_seconds) as client:
            response = client.post(url, json=request_payload, headers=request_headers)
            response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        detail = exc.response.text[:1000]
        raise ModelProviderClientError(
            f"{error_prefix}: {exc.response.status_code} {detail}"
        ) from exc
    except httpx.HTTPError as exc:
        raise ModelProviderClientError(f"{error_prefix}: {exc}") from exc

    try:
        decoded = response.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; surface them through the client's own error type.
        raise ModelProviderClientError(
            f"{error_prefix}: response body is not valid JSON"
        ) from exc
    return _coerce_json_dict(decoded)


def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
    """Return ``payload`` as a dict with string keys, or ``{}`` if it is not a dict."""
    if isinstance(payload, dict):
        return {str(key): value for key, value in payload.items()}
    return {}


def _join_url(base_url: str, path: str) -> str:
    """Join ``base_url`` and ``path`` with exactly one slash between them.

    When the base already ends in ``/v1`` and the path starts with ``v1/``, the
    duplicate ``v1`` segment is dropped from the path.
    """
    normalized_base = base_url.rstrip("/")
    normalized_path = path.strip("/")
    if normalized_path.startswith("v1/") and normalized_base.endswith("/v1"):
        normalized_path = normalized_path.removeprefix("v1/")
    return f"{normalized_base}/{normalized_path}"


def _to_anthropic_messages(
    payload: ChatCompletionRequestContract,
) -> tuple[str | None, list[dict[str, JSONValue]]]:
    """Split ``payload.messages`` into a system prompt and a message list.

    System messages are joined with blank lines into the returned prompt (``None``
    when there are none). Every non-assistant role is sent as ``"user"``, and
    consecutive messages with the same role are merged into one entry. If no
    non-system message exists, a single empty user message is returned.
    """
    system_parts: list[str] = []
    messages: list[dict[str, JSONValue]] = []
    for message in payload.messages:
        if message.role == "system":
            system_parts.append(message.content)
            continue
        role = "assistant" if message.role == "assistant" else "user"
        if messages and messages[-1].get("role") == role:
            previous = messages[-1].get("content")
            messages[-1]["content"] = (
                f"{previous}\n\n{message.content}"
                if isinstance(previous, str)
                else message.content
            )
        else:
            messages.append({"role": role, "content": message.content})
    if not messages:
        messages.append({"role": "user", "content": ""})
    return ("\n\n".join(system_parts) if system_parts else None), messages


def _extract_anthropic_content(payload: dict[str, JSONValue]) -> str:
    """Return the text of an Anthropic response.

    A string ``content`` is returned as-is. For a list, the string ``text`` values
    of its dict items are joined with newlines. Anything else yields ``""``.
    """
    content = payload.get("content")
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    parts: list[str] = []
    for item in content:
        if not isinstance(item, dict):
            continue
        text = item.get("text")
        if isinstance(text, str):
            parts.append(text)
    return "\n".join(parts)


def _read_string(payload: dict[str, JSONValue], key: str) -> str | None:
    """Return ``payload[key]`` when it is a string, otherwise ``None``."""
    value = payload.get(key)
    return value if isinstance(value, str) else None


def _first_choice(payload: dict[str, JSONValue]) -> dict[str, JSONValue] | None:
    """Return ``payload["choices"][0]`` when it is a dict, otherwise ``None``."""
    choices = payload.get("choices")
    if isinstance(choices, list) and choices:
        first_choice = choices[0]
        if isinstance(first_choice, dict):
            return first_choice
    return None


def _extract_response_content(payload: dict[str, JSONValue]) -> str:
    """Return the first choice's ``message.content``, else its ``text``, else ``""``."""
    choice = _first_choice(payload)
    if choice is None:
        return ""
    message = choice.get("message")
    if isinstance(message, dict):
        content = message.get("content")
        if isinstance(content, str):
            return content
    # Fall back to the "text" field when no string message content is present.
    text = choice.get("text")
    return text if isinstance(text, str) else ""


def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
    """Return the first choice's ``finish_reason`` when it is a string, else ``None``."""
    choice = _first_choice(payload)
    if choice is None:
        return None
    finish_reason = choice.get("finish_reason")
    return finish_reason if isinstance(finish_reason, str) else None


def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
    """Return the dict entries of the first choice's ``message.tool_calls``.

    Non-dict entries are skipped and keys are coerced to strings. Returns ``[]``
    when the path is missing or has an unexpected type.
    """
    choice = _first_choice(payload)
    if choice is None:
        return []
    message = choice.get("message")
    if not isinstance(message, dict):
        return []
    tool_calls = message.get("tool_calls")
    if not isinstance(tool_calls, list):
        return []
    return [
        {str(item_key): item_value for item_key, item_value in item.items()}
        for item in tool_calls
        if isinstance(item, dict)
    ]


def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
    """Return ``payload["usage"]`` as a dict with string keys, or ``{}``."""
    usage = payload.get("usage")
    if isinstance(usage, dict):
        return {str(key): value for key, value in usage.items()}
    return {}