provider.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. import httpx
  2. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  3. from core_shared import JSONValue
  4. from app.bootstrap.settings import ModelGatewayServiceSettings
  5. class ModelProviderClientError(Exception):
  6. pass
  7. class ModelProviderClient:
  8. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  9. self.settings = settings
  10. def create_chat_completion(
  11. self,
  12. payload: ChatCompletionRequestContract,
  13. *,
  14. provider_type: str | None = None,
  15. provider_base_url: str | None = None,
  16. provider_api_key: str | None = None,
  17. timeout_seconds: float = 60.0,
  18. ) -> ChatCompletionResponseContract:
  19. if payload.model is None:
  20. raise ModelProviderClientError("model is required for chat completion")
  21. resolved_provider_type = provider_type or self.settings.provider_type
  22. if resolved_provider_type == "anthropic":
  23. return self._create_anthropic_message(
  24. payload,
  25. provider_base_url=provider_base_url,
  26. provider_api_key=provider_api_key,
  27. timeout_seconds=timeout_seconds)
  28. return self._create_openai_compatible_chat_completion(
  29. payload,
  30. provider_base_url=provider_base_url,
  31. provider_api_key=provider_api_key,
  32. timeout_seconds=timeout_seconds)
  33. def list_models(
  34. self,
  35. *,
  36. provider_type: str | None = None,
  37. provider_base_url: str | None = None,
  38. provider_api_key: str | None = None,
  39. timeout_seconds: float = 30.0,
  40. ) -> list[dict[str, JSONValue]]:
  41. resolved_provider_type = provider_type or self.settings.provider_type
  42. if resolved_provider_type == "anthropic":
  43. return self._list_anthropic_models(
  44. provider_base_url=provider_base_url,
  45. provider_api_key=provider_api_key,
  46. timeout_seconds=timeout_seconds)
  47. return self._list_openai_compatible_models(
  48. provider_base_url=provider_base_url,
  49. provider_api_key=provider_api_key,
  50. timeout_seconds=timeout_seconds)
  51. def _list_openai_compatible_models(
  52. self,
  53. *,
  54. provider_base_url: str | None,
  55. provider_api_key: str | None,
  56. timeout_seconds: float) -> list[dict[str, JSONValue]]:
  57. request_headers: dict[str, str] = {"content-type": "application/json"}
  58. api_key = (
  59. provider_api_key
  60. if provider_api_key is not None
  61. else self.settings.provider_api_key
  62. )
  63. if api_key:
  64. request_headers["authorization"] = f"Bearer {api_key}"
  65. try:
  66. base_url = provider_base_url or self.settings.provider_base_url
  67. with httpx.Client(timeout=timeout_seconds) as client:
  68. response = client.get(_join_url(base_url, "models"), headers=request_headers)
  69. response.raise_for_status()
  70. except httpx.HTTPStatusError as exc:
  71. detail = exc.response.text[:1000]
  72. raise ModelProviderClientError(
  73. f"model provider list models failed: {exc.response.status_code} {detail}") from exc
  74. except httpx.HTTPError as exc:
  75. raise ModelProviderClientError(f"model provider list models failed: {exc}") from exc
  76. return _extract_model_items(_coerce_json_dict(response.json()))
  77. def _list_anthropic_models(
  78. self,
  79. *,
  80. provider_base_url: str | None,
  81. provider_api_key: str | None,
  82. timeout_seconds: float) -> list[dict[str, JSONValue]]:
  83. api_key = (
  84. provider_api_key
  85. if provider_api_key is not None
  86. else self.settings.provider_api_key
  87. )
  88. if not api_key:
  89. raise ModelProviderClientError("anthropic api key is required")
  90. request_headers = {
  91. "content-type": "application/json",
  92. "x-api-key": api_key,
  93. "anthropic-version": "2023-06-01",
  94. }
  95. try:
  96. base_url = provider_base_url or self.settings.provider_base_url
  97. with httpx.Client(timeout=timeout_seconds) as client:
  98. response = client.get(_join_url(base_url, "v1/models"), headers=request_headers)
  99. response.raise_for_status()
  100. except httpx.HTTPStatusError as exc:
  101. detail = exc.response.text[:1000]
  102. raise ModelProviderClientError(
  103. f"anthropic list models failed: {exc.response.status_code} {detail}") from exc
  104. except httpx.HTTPError as exc:
  105. raise ModelProviderClientError(f"anthropic list models failed: {exc}") from exc
  106. return _extract_model_items(_coerce_json_dict(response.json()))
  107. def _create_openai_compatible_chat_completion(
  108. self,
  109. payload: ChatCompletionRequestContract,
  110. *,
  111. provider_base_url: str | None,
  112. provider_api_key: str | None,
  113. timeout_seconds: float) -> ChatCompletionResponseContract:
  114. request_payload: dict[str, JSONValue] = {
  115. "model": payload.model or "",
  116. "messages": [item.model_dump(mode="json") for item in payload.messages],
  117. }
  118. if payload.temperature is not None:
  119. request_payload["temperature"] = payload.temperature
  120. if payload.max_tokens is not None:
  121. request_payload["max_tokens"] = payload.max_tokens
  122. if payload.tools_json:
  123. request_payload["tools"] = payload.tools_json
  124. if payload.tool_choice is not None:
  125. request_payload["tool_choice"] = payload.tool_choice
  126. request_headers: dict[str, str] = {"content-type": "application/json"}
  127. api_key = (
  128. provider_api_key
  129. if provider_api_key is not None
  130. else self.settings.provider_api_key
  131. )
  132. if api_key:
  133. request_headers["authorization"] = f"Bearer {api_key}"
  134. try:
  135. base_url = provider_base_url or self.settings.provider_base_url
  136. with httpx.Client(timeout=timeout_seconds) as client:
  137. response = client.post(
  138. _join_url(base_url, "chat/completions"),
  139. json=request_payload,
  140. headers=request_headers)
  141. response.raise_for_status()
  142. except httpx.HTTPStatusError as exc:
  143. detail = exc.response.text[:1000]
  144. raise ModelProviderClientError(
  145. f"model provider request failed: {exc.response.status_code} {detail}") from exc
  146. except httpx.HTTPError as exc:
  147. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  148. response_json = _coerce_json_dict(response.json())
  149. content = _extract_response_content(response_json)
  150. finish_reason = _extract_finish_reason(response_json)
  151. tool_calls_json = _extract_tool_calls_json(response_json)
  152. usage_json = _extract_usage_json(response_json)
  153. return ChatCompletionResponseContract(
  154. model=payload.model,
  155. content=content,
  156. finish_reason=finish_reason,
  157. tool_calls_json=tool_calls_json,
  158. usage_json=usage_json,
  159. raw_response_json=response_json)
  160. def _create_anthropic_message(
  161. self,
  162. payload: ChatCompletionRequestContract,
  163. *,
  164. provider_base_url: str | None,
  165. provider_api_key: str | None,
  166. timeout_seconds: float) -> ChatCompletionResponseContract:
  167. api_key = (
  168. provider_api_key
  169. if provider_api_key is not None
  170. else self.settings.provider_api_key
  171. )
  172. if not api_key:
  173. raise ModelProviderClientError("anthropic api key is required")
  174. system_prompt, messages = _to_anthropic_messages(payload)
  175. request_payload: dict[str, JSONValue] = {
  176. "model": payload.model or "",
  177. "max_tokens": payload.max_tokens or 1024,
  178. "messages": messages,
  179. }
  180. if system_prompt:
  181. request_payload["system"] = system_prompt
  182. if payload.temperature is not None:
  183. request_payload["temperature"] = payload.temperature
  184. request_headers = {
  185. "content-type": "application/json",
  186. "x-api-key": api_key,
  187. "anthropic-version": "2023-06-01",
  188. }
  189. try:
  190. base_url = provider_base_url or self.settings.provider_base_url
  191. with httpx.Client(timeout=timeout_seconds) as client:
  192. response = client.post(
  193. _join_url(base_url, "v1/messages"),
  194. json=request_payload,
  195. headers=request_headers)
  196. response.raise_for_status()
  197. except httpx.HTTPStatusError as exc:
  198. detail = exc.response.text[:1000]
  199. raise ModelProviderClientError(
  200. f"anthropic request failed: {exc.response.status_code} {detail}") from exc
  201. except httpx.HTTPError as exc:
  202. raise ModelProviderClientError(f"anthropic request failed: {exc}") from exc
  203. response_json = _coerce_json_dict(response.json())
  204. return ChatCompletionResponseContract(
  205. model=_read_string(response_json, "model") or payload.model,
  206. content=_extract_anthropic_content(response_json),
  207. finish_reason=_read_string(response_json, "stop_reason"),
  208. tool_calls_json=[],
  209. usage_json=_extract_usage_json(response_json),
  210. raw_response_json=response_json)
  211. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  212. if isinstance(payload, dict):
  213. return {str(key): value for key, value in payload.items()}
  214. return {}
  215. def _join_url(base_url: str, path: str) -> str:
  216. normalized_base = base_url.rstrip("/")
  217. normalized_path = path.strip("/")
  218. if normalized_path.startswith("v1/") and normalized_base.endswith("/v1"):
  219. normalized_path = normalized_path.removeprefix("v1/")
  220. return f"{normalized_base}/{normalized_path}"
  221. def _to_anthropic_messages(
  222. payload: ChatCompletionRequestContract) -> tuple[str | None, list[dict[str, JSONValue]]]:
  223. system_parts: list[str] = []
  224. messages: list[dict[str, JSONValue]] = []
  225. for message in payload.messages:
  226. if message.role == "system":
  227. system_parts.append(message.content)
  228. continue
  229. role = "assistant" if message.role == "assistant" else "user"
  230. if messages and messages[-1].get("role") == role:
  231. previous = messages[-1].get("content")
  232. messages[-1]["content"] = f"{previous}\n\n{message.content}" if isinstance(previous, str) else message.content
  233. else:
  234. messages.append({"role": role, "content": message.content})
  235. if not messages:
  236. messages.append({"role": "user", "content": ""})
  237. return ("\n\n".join(system_parts) if system_parts else None), messages
  238. def _extract_anthropic_content(payload: dict[str, JSONValue]) -> str:
  239. content = payload.get("content")
  240. if isinstance(content, str):
  241. return content
  242. if not isinstance(content, list):
  243. return ""
  244. parts: list[str] = []
  245. for item in content:
  246. if not isinstance(item, dict):
  247. continue
  248. text = item.get("text")
  249. if isinstance(text, str):
  250. parts.append(text)
  251. return "\n".join(parts)
  252. def _read_string(payload: dict[str, JSONValue], key: str) -> str | None:
  253. value = payload.get(key)
  254. return value if isinstance(value, str) else None
  255. def _extract_model_items(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  256. data = payload.get("data")
  257. if not isinstance(data, list):
  258. data = payload.get("models")
  259. if not isinstance(data, list):
  260. return []
  261. items: list[dict[str, JSONValue]] = []
  262. for item in data:
  263. if isinstance(item, str):
  264. model_id = item
  265. display_name = item
  266. owned_by = None
  267. elif isinstance(item, dict):
  268. model_id = _read_string(item, "id") or _read_string(item, "model") or _read_string(item, "name")
  269. if model_id is None:
  270. continue
  271. display_name = _read_string(item, "display_name") or _read_string(item, "displayName") or model_id
  272. owned_by = _read_string(item, "owned_by") or _read_string(item, "ownedBy")
  273. else:
  274. continue
  275. model_item: dict[str, JSONValue] = {
  276. "modelId": model_id,
  277. "displayName": display_name,
  278. "modelType": _infer_model_type(model_id),
  279. }
  280. if owned_by:
  281. model_item["ownedBy"] = owned_by
  282. items.append(model_item)
  283. return items
  284. def _infer_model_type(model_id: str) -> str:
  285. normalized = model_id.lower()
  286. if "embed" in normalized or "embedding" in normalized:
  287. return "embedding"
  288. if "rerank" in normalized or "ranker" in normalized:
  289. return "rerank"
  290. if "moderation" in normalized:
  291. return "moderation"
  292. if "image" in normalized or "vision" in normalized:
  293. return "image"
  294. if "audio" in normalized or "whisper" in normalized or "tts" in normalized:
  295. return "audio"
  296. if "reason" in normalized or "thinking" in normalized or normalized.endswith("-r1") or "-r1-" in normalized:
  297. return "reasoning"
  298. return "chat"
  299. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  300. choices = payload.get("choices")
  301. if isinstance(choices, list) and choices:
  302. first_choice = choices[0]
  303. if isinstance(first_choice, dict):
  304. message = first_choice.get("message")
  305. if isinstance(message, dict):
  306. content = message.get("content")
  307. if isinstance(content, str):
  308. return content
  309. text = first_choice.get("text")
  310. if isinstance(text, str):
  311. return text
  312. return ""
  313. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  314. choices = payload.get("choices")
  315. if isinstance(choices, list) and choices:
  316. first_choice = choices[0]
  317. if isinstance(first_choice, dict):
  318. finish_reason = first_choice.get("finish_reason")
  319. if isinstance(finish_reason, str):
  320. return finish_reason
  321. return None
  322. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  323. choices = payload.get("choices")
  324. if isinstance(choices, list) and choices:
  325. first_choice = choices[0]
  326. if isinstance(first_choice, dict):
  327. message = first_choice.get("message")
  328. if isinstance(message, dict):
  329. tool_calls = message.get("tool_calls")
  330. if isinstance(tool_calls, list):
  331. return [
  332. {str(item_key): item_value for item_key, item_value in item.items()}
  333. for item in tool_calls
  334. if isinstance(item, dict)
  335. ]
  336. return []
  337. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  338. usage = payload.get("usage")
  339. if isinstance(usage, dict):
  340. return {str(key): value for key, value in usage.items()}
  341. return {}