provider.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract) -> ChatCompletionResponseContract:
  13. if payload.model is None:
  14. raise ModelProviderClientError("model is required for chat completion")
  15. request_payload = {
  16. "model": payload.model,
  17. "messages": [item.model_dump(mode="json") for item in payload.messages],
  18. }
  19. if payload.temperature is not None:
  20. request_payload["temperature"] = payload.temperature
  21. if payload.max_tokens is not None:
  22. request_payload["max_tokens"] = payload.max_tokens
  23. if payload.tools_json:
  24. request_payload["tools"] = payload.tools_json
  25. if payload.tool_choice is not None:
  26. request_payload["tool_choice"] = payload.tool_choice
  27. request_headers: dict[str, str] = {"content-type": "application/json"}
  28. if self.settings.provider_api_key:
  29. request_headers["authorization"] = f"Bearer {self.settings.provider_api_key}"
  30. try:
  31. with httpx.Client(timeout=60.0) as client:
  32. response = client.post(
  33. f"{self.settings.provider_base_url.rstrip('/')}/chat/completions",
  34. json=request_payload,
  35. headers=request_headers)
  36. response.raise_for_status()
  37. except httpx.HTTPError as exc:
  38. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  39. response_json = _coerce_json_dict(response.json())
  40. content = _extract_response_content(response_json)
  41. finish_reason = _extract_finish_reason(response_json)
  42. tool_calls_json = _extract_tool_calls_json(response_json)
  43. usage_json = _extract_usage_json(response_json)
  44. return ChatCompletionResponseContract(
  45. model=payload.model,
  46. content=content,
  47. finish_reason=finish_reason,
  48. tool_calls_json=tool_calls_json,
  49. usage_json=usage_json,
  50. raw_response_json=response_json)
  51. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  52. if isinstance(payload, dict):
  53. return {str(key): value for key, value in payload.items()}
  54. return {}
  55. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  56. choices = payload.get("choices")
  57. if isinstance(choices, list) and choices:
  58. first_choice = choices[0]
  59. if isinstance(first_choice, dict):
  60. message = first_choice.get("message")
  61. if isinstance(message, dict):
  62. content = message.get("content")
  63. if isinstance(content, str):
  64. return content
  65. text = first_choice.get("text")
  66. if isinstance(text, str):
  67. return text
  68. return ""
  69. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  70. choices = payload.get("choices")
  71. if isinstance(choices, list) and choices:
  72. first_choice = choices[0]
  73. if isinstance(first_choice, dict):
  74. finish_reason = first_choice.get("finish_reason")
  75. if isinstance(finish_reason, str):
  76. return finish_reason
  77. return None
  78. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  79. choices = payload.get("choices")
  80. if isinstance(choices, list) and choices:
  81. first_choice = choices[0]
  82. if isinstance(first_choice, dict):
  83. message = first_choice.get("message")
  84. if isinstance(message, dict):
  85. tool_calls = message.get("tool_calls")
  86. if isinstance(tool_calls, list):
  87. return [
  88. {str(item_key): item_value for item_key, item_value in item.items()}
  89. for item in tool_calls
  90. if isinstance(item, dict)
  91. ]
  92. return []
  93. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  94. usage = payload.get("usage")
  95. if isinstance(usage, dict):
  96. return {str(key): value for key, value in usage.items()}
  97. return {}