| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- import json
- from collections.abc import Iterator
- import httpx
- from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
class ModelGatewayClientError(Exception):
    """Error raised by the model-gateway client when a request or stream fails."""
class ModelGatewayClient:
    """Synchronous HTTP client for the model-gateway chat-completion endpoints.

    Every call opens its own short-lived ``httpx.Client``. Failures are
    surfaced to callers as ``ModelGatewayClientError``.
    """

    def __init__(self, *, base_url: str, timeout_seconds: float = 60.0) -> None:
        """Remember where the gateway lives and which timeout to use.

        Args:
            base_url: Root URL of the gateway. Trailing slashes are removed so
                endpoint paths can be appended with a single ``/``.
            timeout_seconds: Timeout passed to ``httpx.Client`` for each call.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout_seconds = timeout_seconds

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """POST ``payload`` to the gateway and return the validated response.

        Args:
            payload: Request contract, serialized with ``model_dump(mode="json")``.

        Returns:
            The response body validated into ``ChatCompletionResponseContract``.

        Raises:
            ModelGatewayClientError: If the request fails at the transport
                level, the gateway answers with an error status, or the
                response body cannot be decoded/validated.
        """
        request_body = payload.model_dump(mode="json")
        try:
            with httpx.Client(timeout=self.timeout_seconds) as client:
                response = client.post(
                    f"{self.base_url}/models/chat-completions",
                    json=request_body,
                )
                response.raise_for_status()
        except httpx.HTTPError as exc:
            raise ModelGatewayClientError(
                f"model-gateway-service request failed: {exc}"
            ) from exc

        # The body of a non-streamed response is already loaded, so it can be
        # decoded after the client is closed.
        try:
            return ChatCompletionResponseContract.model_validate(response.json())
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError. Validation failures are
            # expected to be ValueError subclasses too (true for pydantic's
            # ValidationError -- assumes the contracts are pydantic models,
            # TODO confirm). Without this, a malformed body would escape the
            # client's own error type.
            raise ModelGatewayClientError(
                f"model-gateway-service returned an invalid response: {exc}"
            ) from exc

    def stream_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> Iterator[str]:
        """Stream a chat completion and yield its text deltas one by one.

        The gateway response is read as Server-Sent Events. ``delta`` events
        contribute their string ``"delta"`` field; an ``error`` event aborts
        the stream; every other event is ignored.

        Args:
            payload: Request contract, serialized with ``model_dump(mode="json")``.

        Yields:
            Each string delta in the order it is received.

        Raises:
            ModelGatewayClientError: If the request fails, the gateway answers
                with an error status, or the stream carries an ``error`` event.
        """
        request_body = payload.model_dump(mode="json")
        try:
            with httpx.Client(timeout=self.timeout_seconds) as client:
                with client.stream(
                    "POST",
                    f"{self.base_url}/models/chat-completions/stream",
                    json=request_body,
                ) as response:
                    response.raise_for_status()
                    for event_name, data in _iter_sse_events(response):
                        if event_name == "delta":
                            delta = data.get("delta")
                            # Non-string deltas are skipped rather than coerced.
                            if isinstance(delta, str):
                                yield delta
                        elif event_name == "error":
                            message = data.get("message")
                            raise ModelGatewayClientError(
                                message
                                if isinstance(message, str)
                                else "model-gateway stream failed"
                            )
        except httpx.HTTPError as exc:
            raise ModelGatewayClientError(
                f"model-gateway-service stream failed: {exc}"
            ) from exc
def _iter_sse_events(response: httpx.Response) -> Iterator[tuple[str, dict[str, object]]]:
    """Yield ``(event_name, data)`` pairs parsed from a Server-Sent Events body.

    Lines are read from ``response.iter_lines()``. ``event:`` lines set the
    name of the pending event, ``data:`` lines accumulate its payload, and a
    blank line dispatches it. Other lines (comments, unknown fields) are
    ignored. The accumulated data lines are joined with newlines and decoded
    by ``_parse_json``.

    Args:
        response: Streaming response whose body is an SSE stream.

    Yields:
        The event name (``"message"`` when none was given) and the decoded
        JSON payload of every event that carried at least one ``data:`` line.
    """
    default_event = "message"
    event_name = default_event
    data_lines: list[str] = []

    for raw_line in response.iter_lines():
        # Tolerate line iterators that leave the terminator attached
        # (NOTE(review): believed to be the case for older httpx releases --
        # confirm); otherwise the blank-line check below would never match.
        line = raw_line.rstrip("\r\n")

        if not line:
            if data_lines:
                yield event_name, _parse_json("\n".join(data_lines))
                data_lines = []
            # A blank line ends the current event even when it carried no
            # data, so a dangling "event:" field must not rename the next one.
            event_name = default_event
            continue

        if line.startswith("event:"):
            event_name = line.removeprefix("event:").strip()
        elif line.startswith("data:"):
            data_lines.append(line.removeprefix("data:").strip())

    # Best effort: flush an event that was not terminated by a blank line
    # before the stream ended.
    if data_lines:
        yield event_name, _parse_json("\n".join(data_lines))
def _parse_json(value: str) -> dict[str, object]:
    """Decode ``value`` as JSON, returning it only when it is a JSON object.

    Malformed JSON and any non-object document (list, string, number, null)
    both yield an empty dict instead of raising.
    """
    try:
        decoded = json.loads(value)
    except json.JSONDecodeError:
        decoded = None

    if not isinstance(decoded, dict):
        return {}
    return decoded
|