# model_gateway_client.py
  1. import json
  2. from collections.abc import Iterator
  3. import httpx
  4. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  5. class ModelGatewayClientError(Exception):
  6. pass
  7. class ModelGatewayClient:
  8. def __init__(self, *, base_url: str, timeout_seconds: float = 60.0) -> None:
  9. self.base_url = base_url.rstrip("/")
  10. self.timeout_seconds = timeout_seconds
  11. def create_chat_completion(
  12. self,
  13. payload: ChatCompletionRequestContract) -> ChatCompletionResponseContract:
  14. try:
  15. with httpx.Client(timeout=self.timeout_seconds) as client:
  16. response = client.post(
  17. f"{self.base_url}/models/chat-completions",
  18. json=payload.model_dump(mode="json"))
  19. response.raise_for_status()
  20. return ChatCompletionResponseContract.model_validate(response.json())
  21. except httpx.HTTPError as exc:
  22. raise ModelGatewayClientError(f"model-gateway-service request failed: {exc}") from exc
  23. def stream_chat_completion(
  24. self,
  25. payload: ChatCompletionRequestContract) -> Iterator[str]:
  26. try:
  27. with httpx.Client(timeout=self.timeout_seconds) as client:
  28. with client.stream(
  29. "POST",
  30. f"{self.base_url}/models/chat-completions/stream",
  31. json=payload.model_dump(mode="json")) as response:
  32. response.raise_for_status()
  33. for event_name, data in _iter_sse_events(response):
  34. if event_name == "delta":
  35. delta = data.get("delta")
  36. if isinstance(delta, str):
  37. yield delta
  38. elif event_name == "error":
  39. message = data.get("message")
  40. raise ModelGatewayClientError(
  41. str(message) if isinstance(message, str) else "model-gateway stream failed")
  42. except httpx.HTTPError as exc:
  43. raise ModelGatewayClientError(f"model-gateway-service stream failed: {exc}") from exc
  44. def _iter_sse_events(response: httpx.Response) -> Iterator[tuple[str, dict[str, object]]]:
  45. event_name = "message"
  46. data_lines: list[str] = []
  47. for line in response.iter_lines():
  48. if line == "":
  49. if data_lines:
  50. yield event_name, _parse_json("\n".join(data_lines))
  51. event_name = "message"
  52. data_lines = []
  53. continue
  54. if line.startswith("event:"):
  55. event_name = line.removeprefix("event:").strip()
  56. elif line.startswith("data:"):
  57. data_lines.append(line.removeprefix("data:").strip())
  58. if data_lines:
  59. yield event_name, _parse_json("\n".join(data_lines))
  60. def _parse_json(value: str) -> dict[str, object]:
  61. try:
  62. payload = json.loads(value)
  63. except json.JSONDecodeError:
  64. return {}
  65. return payload if isinstance(payload, dict) else {}