"""Synchronous HTTP client for the model-gateway-service chat-completion endpoints."""

import json
from collections.abc import Iterator

import httpx
from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract


class ModelGatewayClientError(Exception):
    """Raised when a call to the model-gateway-service fails."""


class ModelGatewayClient:
    """Thin wrapper around the gateway's blocking and streaming chat endpoints.

    A fresh ``httpx.Client`` is opened for every call and closed when the
    call (or the returned generator) finishes.
    """

    def __init__(self, *, base_url: str, timeout_seconds: float = 60.0) -> None:
        """Store the connection settings.

        Args:
            base_url: Root URL of the gateway; trailing slashes are removed so
                endpoint paths can be appended with a single ``/``.
            timeout_seconds: Value passed as ``timeout`` to ``httpx.Client``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout_seconds = timeout_seconds

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """POST *payload* to ``/models/chat-completions`` and parse the reply.

        Args:
            payload: Request contract, serialized with ``model_dump(mode="json")``.

        Returns:
            The response body validated into ``ChatCompletionResponseContract``.

        Raises:
            ModelGatewayClientError: On any ``httpx.HTTPError`` (transport
                failure or non-success status) or when the response body is
                not valid JSON.

        Errors raised while serializing *payload* or by ``model_validate``
        are not wrapped and propagate unchanged.
        """
        # Serialize outside the try blocks so request-side errors are never
        # misreported as gateway failures.
        request_body = payload.model_dump(mode="json")
        try:
            with httpx.Client(timeout=self.timeout_seconds) as client:
                response = client.post(
                    f"{self.base_url}/models/chat-completions",
                    json=request_body,
                )
                response.raise_for_status()
        except httpx.HTTPError as exc:
            raise ModelGatewayClientError(
                f"model-gateway-service request failed: {exc}"
            ) from exc

        # A 2xx reply with a non-JSON body raises json.JSONDecodeError (a
        # ValueError), which is not an httpx.HTTPError; wrap it so callers
        # only need to handle ModelGatewayClientError for gateway failures.
        try:
            body = response.json()
        except ValueError as exc:
            raise ModelGatewayClientError(
                f"model-gateway-service returned an invalid JSON body: {exc}"
            ) from exc

        return ChatCompletionResponseContract.model_validate(body)

    def stream_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> Iterator[str]:
        """Stream text deltas from ``/models/chat-completions/stream``.

        The response is read as server-sent events. For each ``delta`` event
        whose JSON data has a string ``"delta"`` field, that string is
        yielded. Events with any other name are ignored, except ``error``.

        Args:
            payload: Request contract, serialized with ``model_dump(mode="json")``.

        Yields:
            Each string delta in the order received.

        Raises:
            ModelGatewayClientError: On any ``httpx.HTTPError``, or when the
                stream delivers an ``error`` event. The event's string
                ``"message"`` field is used as the error text when present.
        """
        try:
            with httpx.Client(timeout=self.timeout_seconds) as client:
                with client.stream(
                    "POST",
                    f"{self.base_url}/models/chat-completions/stream",
                    json=payload.model_dump(mode="json"),
                ) as response:
                    response.raise_for_status()
                    for event_name, data in _iter_sse_events(response):
                        if event_name == "delta":
                            delta = data.get("delta")
                            # Non-string deltas (or unparseable data) are skipped.
                            if isinstance(delta, str):
                                yield delta
                        elif event_name == "error":
                            message = data.get("message")
                            raise ModelGatewayClientError(
                                str(message)
                                if isinstance(message, str)
                                else "model-gateway stream failed"
                            )
        except httpx.HTTPError as exc:
            raise ModelGatewayClientError(
                f"model-gateway-service stream failed: {exc}"
            ) from exc


def _iter_sse_events(response: httpx.Response) -> Iterator[tuple[str, dict[str, object]]]:
    """Yield ``(event_name, data)`` pairs parsed from an event-stream response.

    ``event:`` lines set the name of the pending event (default ``"message"``)
    and ``data:`` lines accumulate its payload; all other non-empty lines are
    ignored. An empty line dispatches the pending event if it has data, and
    resets the name either way. Data lines are joined with ``"\\n"`` and
    decoded by ``_parse_json``.
    """
    event_name = "message"
    data_lines: list[str] = []
    for line in response.iter_lines():
        if line == "":
            if data_lines:
                yield event_name, _parse_json("\n".join(data_lines))
            event_name = "message"
            data_lines = []
            continue
        if line.startswith("event:"):
            event_name = line.removeprefix("event:").strip()
        elif line.startswith("data:"):
            data_lines.append(line.removeprefix("data:").strip())
    # Flush an event that was not terminated by a blank line before the
    # stream ended.
    if data_lines:
        yield event_name, _parse_json("\n".join(data_lines))


def _parse_json(value: str) -> dict[str, object]:
    """Decode *value* as a JSON object, returning ``{}`` on failure.

    Both malformed JSON and valid JSON that is not an object (list, string,
    number, ...) produce an empty dict, so one bad event does not abort the
    stream.
    """
    try:
        payload = json.loads(value)
    except json.JSONDecodeError:
        return {}
    return payload if isinstance(payload, dict) else {}