provider.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. import json
  2. from collections.abc import Iterator
  3. import httpx
  4. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  5. from core_shared import JSONValue
  6. from app.bootstrap.settings import ModelGatewayServiceSettings
  7. class ModelProviderClientError(Exception):
  8. pass
  9. class ModelProviderClient:
  10. def __init__(self, *, settings: ModelGatewayServiceSettings) -> None:
  11. self.settings = settings
  12. def create_chat_completion(
  13. self,
  14. payload: ChatCompletionRequestContract,
  15. *,
  16. provider_type: str | None = None,
  17. provider_base_url: str | None = None,
  18. provider_api_key: str | None = None,
  19. timeout_seconds: float = 60.0,
  20. ) -> ChatCompletionResponseContract:
  21. if payload.model is None:
  22. raise ModelProviderClientError("model is required for chat completion")
  23. resolved_provider_type = provider_type or self.settings.provider_type
  24. if resolved_provider_type == "anthropic":
  25. return self._create_anthropic_message(
  26. payload,
  27. provider_base_url=provider_base_url,
  28. provider_api_key=provider_api_key,
  29. timeout_seconds=timeout_seconds)
  30. return self._create_openai_compatible_chat_completion(
  31. payload,
  32. provider_base_url=provider_base_url,
  33. provider_api_key=provider_api_key,
  34. timeout_seconds=timeout_seconds)
  35. def stream_chat_completion(
  36. self,
  37. payload: ChatCompletionRequestContract,
  38. *,
  39. provider_type: str | None = None,
  40. provider_base_url: str | None = None,
  41. provider_api_key: str | None = None,
  42. timeout_seconds: float = 60.0,
  43. ) -> Iterator[str]:
  44. if payload.model is None:
  45. raise ModelProviderClientError("model is required for chat completion")
  46. resolved_provider_type = provider_type or self.settings.provider_type
  47. if resolved_provider_type == "anthropic":
  48. yield from self._stream_anthropic_message(
  49. payload,
  50. provider_base_url=provider_base_url,
  51. provider_api_key=provider_api_key,
  52. timeout_seconds=timeout_seconds)
  53. return
  54. yield from self._stream_openai_compatible_chat_completion(
  55. payload,
  56. provider_base_url=provider_base_url,
  57. provider_api_key=provider_api_key,
  58. timeout_seconds=timeout_seconds)
  59. def list_models(
  60. self,
  61. *,
  62. provider_type: str | None = None,
  63. provider_base_url: str | None = None,
  64. provider_api_key: str | None = None,
  65. timeout_seconds: float = 30.0,
  66. ) -> list[dict[str, JSONValue]]:
  67. resolved_provider_type = provider_type or self.settings.provider_type
  68. if resolved_provider_type == "anthropic":
  69. return self._list_anthropic_models(
  70. provider_base_url=provider_base_url,
  71. provider_api_key=provider_api_key,
  72. timeout_seconds=timeout_seconds)
  73. return self._list_openai_compatible_models(
  74. provider_base_url=provider_base_url,
  75. provider_api_key=provider_api_key,
  76. timeout_seconds=timeout_seconds)
  77. def _list_openai_compatible_models(
  78. self,
  79. *,
  80. provider_base_url: str | None,
  81. provider_api_key: str | None,
  82. timeout_seconds: float) -> list[dict[str, JSONValue]]:
  83. request_headers: dict[str, str] = {"content-type": "application/json"}
  84. api_key = (
  85. provider_api_key
  86. if provider_api_key is not None
  87. else self.settings.provider_api_key
  88. )
  89. if api_key:
  90. request_headers["authorization"] = f"Bearer {api_key}"
  91. try:
  92. base_url = provider_base_url or self.settings.provider_base_url
  93. with httpx.Client(timeout=timeout_seconds) as client:
  94. response = client.get(_join_url(base_url, "models"), headers=request_headers)
  95. response.raise_for_status()
  96. except httpx.HTTPStatusError as exc:
  97. detail = exc.response.text[:1000]
  98. raise ModelProviderClientError(
  99. f"model provider list models failed: {exc.response.status_code} {detail}") from exc
  100. except httpx.HTTPError as exc:
  101. raise ModelProviderClientError(f"model provider list models failed: {exc}") from exc
  102. return _extract_model_items(_coerce_json_dict(response.json()))
  103. def _list_anthropic_models(
  104. self,
  105. *,
  106. provider_base_url: str | None,
  107. provider_api_key: str | None,
  108. timeout_seconds: float) -> list[dict[str, JSONValue]]:
  109. api_key = (
  110. provider_api_key
  111. if provider_api_key is not None
  112. else self.settings.provider_api_key
  113. )
  114. if not api_key:
  115. raise ModelProviderClientError("anthropic api key is required")
  116. request_headers = {
  117. "content-type": "application/json",
  118. "x-api-key": api_key,
  119. "anthropic-version": "2023-06-01",
  120. }
  121. try:
  122. base_url = provider_base_url or self.settings.provider_base_url
  123. with httpx.Client(timeout=timeout_seconds) as client:
  124. response = client.get(_join_url(base_url, "v1/models"), headers=request_headers)
  125. response.raise_for_status()
  126. except httpx.HTTPStatusError as exc:
  127. detail = exc.response.text[:1000]
  128. raise ModelProviderClientError(
  129. f"anthropic list models failed: {exc.response.status_code} {detail}") from exc
  130. except httpx.HTTPError as exc:
  131. raise ModelProviderClientError(f"anthropic list models failed: {exc}") from exc
  132. return _extract_model_items(_coerce_json_dict(response.json()))
  133. def _create_openai_compatible_chat_completion(
  134. self,
  135. payload: ChatCompletionRequestContract,
  136. *,
  137. provider_base_url: str | None,
  138. provider_api_key: str | None,
  139. timeout_seconds: float) -> ChatCompletionResponseContract:
  140. request_payload: dict[str, JSONValue] = {
  141. "model": payload.model or "",
  142. "messages": [item.model_dump(mode="json") for item in payload.messages],
  143. }
  144. if payload.temperature is not None:
  145. request_payload["temperature"] = payload.temperature
  146. if payload.max_tokens is not None:
  147. request_payload["max_tokens"] = payload.max_tokens
  148. if payload.tools_json:
  149. request_payload["tools"] = payload.tools_json
  150. if payload.tool_choice is not None:
  151. request_payload["tool_choice"] = payload.tool_choice
  152. request_headers: dict[str, str] = {"content-type": "application/json"}
  153. api_key = (
  154. provider_api_key
  155. if provider_api_key is not None
  156. else self.settings.provider_api_key
  157. )
  158. if api_key:
  159. request_headers["authorization"] = f"Bearer {api_key}"
  160. try:
  161. base_url = provider_base_url or self.settings.provider_base_url
  162. with httpx.Client(timeout=timeout_seconds) as client:
  163. response = client.post(
  164. _join_url(base_url, "chat/completions"),
  165. json=request_payload,
  166. headers=request_headers)
  167. response.raise_for_status()
  168. except httpx.HTTPStatusError as exc:
  169. detail = exc.response.text[:1000]
  170. raise ModelProviderClientError(
  171. f"model provider request failed: {exc.response.status_code} {detail}") from exc
  172. except httpx.HTTPError as exc:
  173. raise ModelProviderClientError(f"model provider request failed: {exc}") from exc
  174. response_json = _coerce_json_dict(response.json())
  175. content = _extract_response_content(response_json)
  176. finish_reason = _extract_finish_reason(response_json)
  177. tool_calls_json = _extract_tool_calls_json(response_json)
  178. usage_json = _extract_usage_json(response_json)
  179. return ChatCompletionResponseContract(
  180. model=payload.model,
  181. content=content,
  182. finish_reason=finish_reason,
  183. tool_calls_json=tool_calls_json,
  184. usage_json=usage_json,
  185. raw_response_json=response_json)
  186. def _stream_openai_compatible_chat_completion(
  187. self,
  188. payload: ChatCompletionRequestContract,
  189. *,
  190. provider_base_url: str | None,
  191. provider_api_key: str | None,
  192. timeout_seconds: float) -> Iterator[str]:
  193. request_payload = _build_openai_request_payload(payload)
  194. request_payload["stream"] = True
  195. request_headers = _build_openai_headers(
  196. settings=self.settings,
  197. provider_api_key=provider_api_key)
  198. try:
  199. base_url = provider_base_url or self.settings.provider_base_url
  200. with httpx.Client(timeout=timeout_seconds) as client:
  201. with client.stream(
  202. "POST",
  203. _join_url(base_url, "chat/completions"),
  204. json=request_payload,
  205. headers=request_headers) as response:
  206. response.raise_for_status()
  207. for line in response.iter_lines():
  208. if not line.startswith("data:"):
  209. continue
  210. data = line.removeprefix("data:").strip()
  211. if data == "[DONE]":
  212. break
  213. try:
  214. payload_json = _coerce_json_dict(json.loads(data))
  215. except json.JSONDecodeError:
  216. continue
  217. delta = _extract_openai_stream_delta(payload_json)
  218. if delta:
  219. yield delta
  220. except httpx.HTTPStatusError as exc:
  221. detail = exc.response.text[:1000]
  222. raise ModelProviderClientError(
  223. f"model provider stream failed: {exc.response.status_code} {detail}") from exc
  224. except httpx.HTTPError as exc:
  225. raise ModelProviderClientError(f"model provider stream failed: {exc}") from exc
  226. def _create_anthropic_message(
  227. self,
  228. payload: ChatCompletionRequestContract,
  229. *,
  230. provider_base_url: str | None,
  231. provider_api_key: str | None,
  232. timeout_seconds: float) -> ChatCompletionResponseContract:
  233. api_key = (
  234. provider_api_key
  235. if provider_api_key is not None
  236. else self.settings.provider_api_key
  237. )
  238. if not api_key:
  239. raise ModelProviderClientError("anthropic api key is required")
  240. system_prompt, messages = _to_anthropic_messages(payload)
  241. request_payload: dict[str, JSONValue] = {
  242. "model": payload.model or "",
  243. "max_tokens": payload.max_tokens or 1024,
  244. "messages": messages,
  245. }
  246. if system_prompt:
  247. request_payload["system"] = system_prompt
  248. if payload.temperature is not None:
  249. request_payload["temperature"] = payload.temperature
  250. request_headers = {
  251. "content-type": "application/json",
  252. "x-api-key": api_key,
  253. "anthropic-version": "2023-06-01",
  254. }
  255. try:
  256. base_url = provider_base_url or self.settings.provider_base_url
  257. with httpx.Client(timeout=timeout_seconds) as client:
  258. response = client.post(
  259. _join_url(base_url, "v1/messages"),
  260. json=request_payload,
  261. headers=request_headers)
  262. response.raise_for_status()
  263. except httpx.HTTPStatusError as exc:
  264. detail = exc.response.text[:1000]
  265. raise ModelProviderClientError(
  266. f"anthropic request failed: {exc.response.status_code} {detail}") from exc
  267. except httpx.HTTPError as exc:
  268. raise ModelProviderClientError(f"anthropic request failed: {exc}") from exc
  269. response_json = _coerce_json_dict(response.json())
  270. return ChatCompletionResponseContract(
  271. model=_read_string(response_json, "model") or payload.model,
  272. content=_extract_anthropic_content(response_json),
  273. finish_reason=_read_string(response_json, "stop_reason"),
  274. tool_calls_json=[],
  275. usage_json=_extract_usage_json(response_json),
  276. raw_response_json=response_json)
  277. def _stream_anthropic_message(
  278. self,
  279. payload: ChatCompletionRequestContract,
  280. *,
  281. provider_base_url: str | None,
  282. provider_api_key: str | None,
  283. timeout_seconds: float) -> Iterator[str]:
  284. api_key = (
  285. provider_api_key
  286. if provider_api_key is not None
  287. else self.settings.provider_api_key
  288. )
  289. if not api_key:
  290. raise ModelProviderClientError("anthropic api key is required")
  291. system_prompt, messages = _to_anthropic_messages(payload)
  292. request_payload: dict[str, JSONValue] = {
  293. "model": payload.model or "",
  294. "max_tokens": payload.max_tokens or 1024,
  295. "messages": messages,
  296. "stream": True,
  297. }
  298. if system_prompt:
  299. request_payload["system"] = system_prompt
  300. if payload.temperature is not None:
  301. request_payload["temperature"] = payload.temperature
  302. request_headers = {
  303. "content-type": "application/json",
  304. "x-api-key": api_key,
  305. "anthropic-version": "2023-06-01",
  306. }
  307. try:
  308. base_url = provider_base_url or self.settings.provider_base_url
  309. with httpx.Client(timeout=timeout_seconds) as client:
  310. with client.stream(
  311. "POST",
  312. _join_url(base_url, "v1/messages"),
  313. json=request_payload,
  314. headers=request_headers) as response:
  315. response.raise_for_status()
  316. for line in response.iter_lines():
  317. if not line.startswith("data:"):
  318. continue
  319. data = line.removeprefix("data:").strip()
  320. try:
  321. payload_json = _coerce_json_dict(json.loads(data))
  322. except json.JSONDecodeError:
  323. continue
  324. delta = _extract_anthropic_stream_delta(payload_json)
  325. if delta:
  326. yield delta
  327. except httpx.HTTPStatusError as exc:
  328. detail = exc.response.text[:1000]
  329. raise ModelProviderClientError(
  330. f"anthropic stream failed: {exc.response.status_code} {detail}") from exc
  331. except httpx.HTTPError as exc:
  332. raise ModelProviderClientError(f"anthropic stream failed: {exc}") from exc
  333. def _coerce_json_dict(payload: JSONValue) -> dict[str, JSONValue]:
  334. if isinstance(payload, dict):
  335. return {str(key): value for key, value in payload.items()}
  336. return {}
  337. def _build_openai_request_payload(
  338. payload: ChatCompletionRequestContract) -> dict[str, JSONValue]:
  339. request_payload: dict[str, JSONValue] = {
  340. "model": payload.model or "",
  341. "messages": [item.model_dump(mode="json") for item in payload.messages],
  342. }
  343. if payload.temperature is not None:
  344. request_payload["temperature"] = payload.temperature
  345. if payload.max_tokens is not None:
  346. request_payload["max_tokens"] = payload.max_tokens
  347. if payload.tools_json:
  348. request_payload["tools"] = payload.tools_json
  349. if payload.tool_choice is not None:
  350. request_payload["tool_choice"] = payload.tool_choice
  351. return request_payload
  352. def _build_openai_headers(
  353. *,
  354. settings: ModelGatewayServiceSettings,
  355. provider_api_key: str | None) -> dict[str, str]:
  356. request_headers: dict[str, str] = {"content-type": "application/json"}
  357. api_key = (
  358. provider_api_key
  359. if provider_api_key is not None
  360. else settings.provider_api_key
  361. )
  362. if api_key:
  363. request_headers["authorization"] = f"Bearer {api_key}"
  364. return request_headers
  365. def _join_url(base_url: str, path: str) -> str:
  366. normalized_base = base_url.rstrip("/")
  367. normalized_path = path.strip("/")
  368. if normalized_path.startswith("v1/") and normalized_base.endswith("/v1"):
  369. normalized_path = normalized_path.removeprefix("v1/")
  370. return f"{normalized_base}/{normalized_path}"
  371. def _to_anthropic_messages(
  372. payload: ChatCompletionRequestContract) -> tuple[str | None, list[dict[str, JSONValue]]]:
  373. system_parts: list[str] = []
  374. messages: list[dict[str, JSONValue]] = []
  375. for message in payload.messages:
  376. if message.role == "system":
  377. system_parts.append(message.content)
  378. continue
  379. role = "assistant" if message.role == "assistant" else "user"
  380. if messages and messages[-1].get("role") == role:
  381. previous = messages[-1].get("content")
  382. messages[-1]["content"] = f"{previous}\n\n{message.content}" if isinstance(previous, str) else message.content
  383. else:
  384. messages.append({"role": role, "content": message.content})
  385. if not messages:
  386. messages.append({"role": "user", "content": ""})
  387. return ("\n\n".join(system_parts) if system_parts else None), messages
  388. def _extract_anthropic_content(payload: dict[str, JSONValue]) -> str:
  389. content = payload.get("content")
  390. if isinstance(content, str):
  391. return content
  392. if not isinstance(content, list):
  393. return ""
  394. parts: list[str] = []
  395. for item in content:
  396. if not isinstance(item, dict):
  397. continue
  398. text = item.get("text")
  399. if isinstance(text, str):
  400. parts.append(text)
  401. return "\n".join(parts)
  402. def _read_string(payload: dict[str, JSONValue], key: str) -> str | None:
  403. value = payload.get(key)
  404. return value if isinstance(value, str) else None
  405. def _extract_model_items(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  406. data = payload.get("data")
  407. if not isinstance(data, list):
  408. data = payload.get("models")
  409. if not isinstance(data, list):
  410. return []
  411. items: list[dict[str, JSONValue]] = []
  412. for item in data:
  413. if isinstance(item, str):
  414. model_id = item
  415. display_name = item
  416. owned_by = None
  417. elif isinstance(item, dict):
  418. model_id = _read_string(item, "id") or _read_string(item, "model") or _read_string(item, "name")
  419. if model_id is None:
  420. continue
  421. display_name = _read_string(item, "display_name") or _read_string(item, "displayName") or model_id
  422. owned_by = _read_string(item, "owned_by") or _read_string(item, "ownedBy")
  423. else:
  424. continue
  425. model_item: dict[str, JSONValue] = {
  426. "modelId": model_id,
  427. "displayName": display_name,
  428. "modelType": _infer_model_type(model_id),
  429. }
  430. if owned_by:
  431. model_item["ownedBy"] = owned_by
  432. items.append(model_item)
  433. return items
  434. def _infer_model_type(model_id: str) -> str:
  435. normalized = model_id.lower()
  436. if "embed" in normalized or "embedding" in normalized:
  437. return "embedding"
  438. if "rerank" in normalized or "ranker" in normalized:
  439. return "rerank"
  440. if "moderation" in normalized:
  441. return "moderation"
  442. if "image" in normalized or "vision" in normalized:
  443. return "image"
  444. if "audio" in normalized or "whisper" in normalized or "tts" in normalized:
  445. return "audio"
  446. if "reason" in normalized or "thinking" in normalized or normalized.endswith("-r1") or "-r1-" in normalized:
  447. return "reasoning"
  448. return "chat"
  449. def _extract_response_content(payload: dict[str, JSONValue]) -> str:
  450. choices = payload.get("choices")
  451. if isinstance(choices, list) and choices:
  452. first_choice = choices[0]
  453. if isinstance(first_choice, dict):
  454. message = first_choice.get("message")
  455. if isinstance(message, dict):
  456. content = message.get("content")
  457. if isinstance(content, str):
  458. return content
  459. text = first_choice.get("text")
  460. if isinstance(text, str):
  461. return text
  462. return ""
  463. def _extract_openai_stream_delta(payload: dict[str, JSONValue]) -> str:
  464. choices = payload.get("choices")
  465. if not isinstance(choices, list) or not choices:
  466. return ""
  467. first_choice = choices[0]
  468. if not isinstance(first_choice, dict):
  469. return ""
  470. delta = first_choice.get("delta")
  471. if isinstance(delta, dict):
  472. content = delta.get("content")
  473. if isinstance(content, str):
  474. return content
  475. text = first_choice.get("text")
  476. return text if isinstance(text, str) else ""
  477. def _extract_anthropic_stream_delta(payload: dict[str, JSONValue]) -> str:
  478. if payload.get("type") != "content_block_delta":
  479. return ""
  480. delta = payload.get("delta")
  481. if not isinstance(delta, dict):
  482. return ""
  483. text = delta.get("text")
  484. return text if isinstance(text, str) else ""
  485. def _extract_finish_reason(payload: dict[str, JSONValue]) -> str | None:
  486. choices = payload.get("choices")
  487. if isinstance(choices, list) and choices:
  488. first_choice = choices[0]
  489. if isinstance(first_choice, dict):
  490. finish_reason = first_choice.get("finish_reason")
  491. if isinstance(finish_reason, str):
  492. return finish_reason
  493. return None
  494. def _extract_tool_calls_json(payload: dict[str, JSONValue]) -> list[dict[str, JSONValue]]:
  495. choices = payload.get("choices")
  496. if isinstance(choices, list) and choices:
  497. first_choice = choices[0]
  498. if isinstance(first_choice, dict):
  499. message = first_choice.get("message")
  500. if isinstance(message, dict):
  501. tool_calls = message.get("tool_calls")
  502. if isinstance(tool_calls, list):
  503. return [
  504. {str(item_key): item_value for item_key, item_value in item.items()}
  505. for item in tool_calls
  506. if isinstance(item, dict)
  507. ]
  508. return []
  509. def _extract_usage_json(payload: dict[str, JSONValue]) -> dict[str, JSONValue]:
  510. usage = payload.get("usage")
  511. if isinstance(usage, dict):
  512. return {str(key): value for key, value in usage.items()}
  513. return {}