provider.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract,
  13. *,
  14. provider_type: str | None = None,
  15. provider_base_url: str | None = None,
  16. provider_api_key: str | None = None,
  17. timeout_seconds: float = 60.0,
  18. ) -> ChatCompletionResponseContract:
  19. if payload.model is None:
  20. raise ModelProviderClientError("model is required for chat completion")
  21. resolved_provider_type = provider_type or self.settings.provider_type
  22. if resolved_provider_type == "anthropic":
  23. return self._create_anthropic_message(
  24. payload,
  25. provider_base_url=provider_base_url,
  26. provider_api_key=provider_api_key,
  27. timeout_seconds=timeout_seconds)
  28. return self._create_openai_compatible_chat_completion(
  29. payload,
  30. provider_base_url=provider_base_url,
  31. provider_api_key=provider_api_key,
  32. timeout_seconds=timeout_seconds)
  33. def _create_openai_compatible_chat_completion(
  34. self,
  35. payload: ChatCompletionRequestContract,
  36. *,
  37. provider_base_url: str | None,
  38. provider_api_key: str | None,
  39. timeout_seconds: float) -> ChatCompletionResponseContract:
  40. request_payload: dict[str, JSONValue] = {
  41. "model": payload.model or "",
  42. "messages": [item.model_dump(mode="json") for item in payload.messages],
  43. }
  44. if payload.temperature is not None:
  45. request_payload["temperature"] = payload.temperature
  46. if payload.max_tokens is not None:
  47. request_payload["max_tokens"] = payload.max_tokens
  48. if payload.tools_json:
  49. request_payload["tools"] = payload.tools_json
  50. if payload.tool_choice is not None:
  51. request_payload["tool_choice"] = payload.tool_choice
  52. request_headers: dict[str, str] = {"content-type": "application/json"}
  53. api_key = (
  54. provider_api_key
  55. if provider_api_key is not None
  56. else self.settings.provider_api_key
  57. )
  58. if api_key:
  59. request_headers["authorization"] = f"Bearer {api_key}"
  60. try:
  61. base_url = provider_base_url or self.settings.provider_base_url
  62. with httpx.Client(timeout=timeout_seconds) as client:
  63. response = client.post(
  64. _join_url(base_url, "chat/completions"),
  65. json=request_payload,
  66. headers=request_headers)
  67. response.raise_for_status()
  68. except httpx.HTTPStatusError as exc:
  69. detail = exc.response.text[:1000]
  70. raise ModelProviderClientError(
  71. f"model provider request failed: {exc.response.status_code} {detail}") from exc
  72. except httpx.HTTPError as exc:
  73. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  74. response_json = _coerce_json_dict(response.json())
  75. content = _extract_response_content(response_json)
  76. finish_reason = _extract_finish_reason(response_json)
  77. tool_calls_json = _extract_tool_calls_json(response_json)
  78. usage_json = _extract_usage_json(response_json)
  79. return ChatCompletionResponseContract(
  80. model=payload.model,
  81. content=content,
  82. finish_reason=finish_reason,
  83. tool_calls_json=tool_calls_json,
  84. usage_json=usage_json,
  85. raw_response_json=response_json)
  86. def _create_anthropic_message(
  87. self,
  88. payload: ChatCompletionRequestContract,
  89. *,
  90. provider_base_url: str | None,
  91. provider_api_key: str | None,
  92. timeout_seconds: float) -> ChatCompletionResponseContract:
  93. api_key = (
  94. provider_api_key
  95. if provider_api_key is not None
  96. else self.settings.provider_api_key
  97. )
  98. if not api_key:
  99. raise ModelProviderClientError("anthropic api key is required")
  100. system_prompt, messages = _to_anthropic_messages(payload)
  101. request_payload: dict[str, JSONValue] = {
  102. "model": payload.model or "",
  103. "max_tokens": payload.max_tokens or 1024,
  104. "messages": messages,
  105. }
  106. if system_prompt:
  107. request_payload["system"] = system_prompt
  108. if payload.temperature is not None:
  109. request_payload["temperature"] = payload.temperature
  110. request_headers = {
  111. "content-type": "application/json",
  112. "x-api-key": api_key,
  113. "anthropic-version": "2023-06-01",
  114. }
  115. try:
  116. base_url = provider_base_url or self.settings.provider_base_url
  117. with httpx.Client(timeout=timeout_seconds) as client:
  118. response = client.post(
  119. _join_url(base_url, "v1/messages"),
  120. json=request_payload,
  121. headers=request_headers)
  122. response.raise_for_status()
  123. except httpx.HTTPStatusError as exc:
  124. detail = exc.response.text[:1000]
  125. raise ModelProviderClientError(
  126. f"anthropic request failed: {exc.response.status_code} {detail}") from exc
  127. except httpx.HTTPError as exc:
  128. raise ModelProviderClientError(f"anthropic request failed: {exc}") from exc
  129. response_json = _coerce_json_dict(response.json())
  130. return ChatCompletionResponseContract(
  131. model=_read_string(response_json, "model") or payload.model,
  132. content=_extract_anthropic_content(response_json),
  133. finish_reason=_read_string(response_json, "stop_reason"),
  134. tool_calls_json=[],
  135. usage_json=_extract_usage_json(response_json),
  136. raw_response_json=response_json)
  137. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  138. if isinstance(payload, dict):
  139. return {str(key): value for key, value in payload.items()}
  140. return {}
  141. def _join_url(base_url: str, path: str) -> str:
  142. normalized_base = base_url.rstrip("/")
  143. normalized_path = path.strip("/")
  144. if normalized_path.startswith("v1/") and normalized_base.endswith("/v1"):
  145. normalized_path = normalized_path.removeprefix("v1/")
  146. return f"{normalized_base}/{normalized_path}"
  147. def _to_anthropic_messages(
  148. payload: ChatCompletionRequestContract) -> tuple[str | None, list[dict[str, JSONValue]]]:
  149. system_parts: list[str] = []
  150. messages: list[dict[str, JSONValue]] = []
  151. for message in payload.messages:
  152. if message.role == "system":
  153. system_parts.append(message.content)
  154. continue
  155. role = "assistant" if message.role == "assistant" else "user"
  156. if messages and messages[-1].get("role") == role:
  157. previous = messages[-1].get("content")
  158. messages[-1]["content"] = f"{previous}\n\n{message.content}" if isinstance(previous, str) else message.content
  159. else:
  160. messages.append({"role": role, "content": message.content})
  161. if not messages:
  162. messages.append({"role": "user", "content": ""})
  163. return ("\n\n".join(system_parts) if system_parts else None), messages
  164. def _extract_anthropic_content(payload: dict[str, JSONValue]) -> str:
  165. content = payload.get("content")
  166. if isinstance(content, str):
  167. return content
  168. if not isinstance(content, list):
  169. return ""
  170. parts: list[str] = []
  171. for item in content:
  172. if not isinstance(item, dict):
  173. continue
  174. text = item.get("text")
  175. if isinstance(text, str):
  176. parts.append(text)
  177. return "\n".join(parts)
  178. def _read_string(payload: dict[str, JSONValue], key: str) -> str | None:
  179. value = payload.get(key)
  180. return value if isinstance(value, str) else None
  181. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  182. choices = payload.get("choices")
  183. if isinstance(choices, list) and choices:
  184. first_choice = choices[0]
  185. if isinstance(first_choice, dict):
  186. message = first_choice.get("message")
  187. if isinstance(message, dict):
  188. content = message.get("content")
  189. if isinstance(content, str):
  190. return content
  191. text = first_choice.get("text")
  192. if isinstance(text, str):
  193. return text
  194. return ""
  195. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  196. choices = payload.get("choices")
  197. if isinstance(choices, list) and choices:
  198. first_choice = choices[0]
  199. if isinstance(first_choice, dict):
  200. finish_reason = first_choice.get("finish_reason")
  201. if isinstance(finish_reason, str):
  202. return finish_reason
  203. return None
  204. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  205. choices = payload.get("choices")
  206. if isinstance(choices, list) and choices:
  207. first_choice = choices[0]
  208. if isinstance(first_choice, dict):
  209. message = first_choice.get("message")
  210. if isinstance(message, dict):
  211. tool_calls = message.get("tool_calls")
  212. if isinstance(tool_calls, list):
  213. return [
  214. {str(item_key): item_value for item_key, item_value in item.items()}
  215. for item in tool_calls
  216. if isinstance(item, dict)
  217. ]
  218. return []
  219. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  220. usage = payload.get("usage")
  221. if isinstance(usage, dict):
  222. return {str(key): value for key, value in usage.items()}
  223. return {}