provider.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract,
  13. ) -> ChatCompletionResponseContract:
  14. if payload.model is None:
  15. raise ModelProviderClientError("model is required for chat completion")
  16. request_payload = {
  17. "model": payload.model,
  18. "messages": [item.model_dump(mode="json") for item in payload.messages],
  19. }
  20. if payload.temperature is not None:
  21. request_payload["temperature"] = payload.temperature
  22. if payload.max_tokens is not None:
  23. request_payload["max_tokens"] = payload.max_tokens
  24. if payload.tools_json:
  25. request_payload["tools"] = payload.tools_json
  26. if payload.tool_choice is not None:
  27. request_payload["tool_choice"] = payload.tool_choice
  28. request_headers: dict[str, str] = {"content-type": "application/json"}
  29. if self.settings.provider_api_key:
  30. request_headers["authorization"] = f"Bearer {self.settings.provider_api_key}"
  31. try:
  32. with httpx.Client(timeout=60.0) as client:
  33. response = client.post(
  34. f"{self.settings.provider_base_url.rstrip('/')}/chat/completions",
  35. json=request_payload,
  36. headers=request_headers,
  37. )
  38. response.raise_for_status()
  39. except httpx.HTTPError as exc:
  40. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  41. response_json = _coerce_json_dict(response.json())
  42. content = _extract_response_content(response_json)
  43. finish_reason = _extract_finish_reason(response_json)
  44. tool_calls_json = _extract_tool_calls_json(response_json)
  45. usage_json = _extract_usage_json(response_json)
  46. return ChatCompletionResponseContract(
  47. model=payload.model,
  48. content=content,
  49. finish_reason=finish_reason,
  50. tool_calls_json=tool_calls_json,
  51. usage_json=usage_json,
  52. raw_response_json=response_json,
  53. )
  54. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  55. if isinstance(payload, dict):
  56. return {str(key): value for key, value in payload.items()}
  57. return {}
  58. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  59. choices = payload.get("choices")
  60. if isinstance(choices, list) and choices:
  61. first_choice = choices[0]
  62. if isinstance(first_choice, dict):
  63. message = first_choice.get("message")
  64. if isinstance(message, dict):
  65. content = message.get("content")
  66. if isinstance(content, str):
  67. return content
  68. text = first_choice.get("text")
  69. if isinstance(text, str):
  70. return text
  71. return ""
  72. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  73. choices = payload.get("choices")
  74. if isinstance(choices, list) and choices:
  75. first_choice = choices[0]
  76. if isinstance(first_choice, dict):
  77. finish_reason = first_choice.get("finish_reason")
  78. if isinstance(finish_reason, str):
  79. return finish_reason
  80. return None
  81. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  82. choices = payload.get("choices")
  83. if isinstance(choices, list) and choices:
  84. first_choice = choices[0]
  85. if isinstance(first_choice, dict):
  86. message = first_choice.get("message")
  87. if isinstance(message, dict):
  88. tool_calls = message.get("tool_calls")
  89. if isinstance(tool_calls, list):
  90. return [
  91. {str(item_key): item_value for item_key, item_value in item.items()}
  92. for item in tool_calls
  93. if isinstance(item, dict)
  94. ]
  95. return []
  96. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  97. usage = payload.get("usage")
  98. if isinstance(usage, dict):
  99. return {str(key): value for key, value in usage.items()}
  100. return {}