# services.py
from collections.abc import Iterator
from dataclasses import dataclass, field

from core_domain import (
    ChatCompletionRequestContract,
    ChatCompletionResponseContract,
    EmbeddingRequestContract,
    EmbeddingResponseContract,
)

from app.bootstrap.settings import ModelGatewayServiceSettings
from app.db.models import ModelDefinition, ModelProviderDefinition
from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
from app.schemas.model import (
    DiscoverModelsData,
    DiscoverModelsRequestDto,
    ModelCreateRequest,
    ModelCreateRequestDto,
    ModelDeleteRequestDto,
    ModelDto,
    ModelItemDto,
    ModelProviderCreateRequestDto,
    ModelProviderDeleteRequestDto,
    ModelProviderDto,
    ModelProviderTestData,
    ModelProviderTestRequestDto,
    ModelProviderUpdateRequestDto,
    ModelResponse,
    ModelStatusUpdateRequest,
    ModelTestData,
    ModelTestRequest,
    ModelTestRequestDto,
    ModelTestResponse,
    ModelUpdateRequest,
    ModelUpdateRequestDto,
    _to_snake_model_item,
)


  35. class ModelGatewayApplicationService:
  36. def __init__(
  37. self,
  38. *,
  39. model_repository: ModelDefinitionRepository,
  40. provider_repository: ModelProviderDefinitionRepository,
  41. provider_client: ModelProviderClient,
  42. settings: ModelGatewayServiceSettings) -> None:
  43. self.model_repository = model_repository
  44. self.provider_repository = provider_repository
  45. self.provider_client = provider_client
  46. self.settings = settings
  47. def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
  48. provider = self._get_provider_or_raise(payload.provider_id)
  49. if provider is not None:
  50. existing = self.model_repository.get_by_provider_model(
  51. provider_id=provider.id,
  52. model_name=payload.model_name)
  53. if existing is not None:
  54. existing.name = payload.name
  55. existing.provider_type = provider.provider_type
  56. existing.provider_base_url = provider.base_url
  57. existing.provider_api_key = provider.api_key
  58. existing.description = payload.description
  59. existing.capabilities_json = payload.capabilities_json
  60. existing.context_window = payload.context_window or existing.context_window
  61. existing.max_output_tokens = payload.max_output_tokens
  62. existing.default_temperature = payload.default_temperature
  63. existing.timeout_seconds = payload.timeout_seconds
  64. existing.metadata_json = {
  65. **(existing.metadata_json or {}),
  66. **payload.metadata_json,
  67. }
  68. return self.model_repository.update(existing)
  69. code = payload.code or self._build_model_code(payload.name, payload.model_name)
  70. if self.model_repository.get_by_code(code) is not None:
  71. raise ValueError(f"model code already exists: {code}")
  72. return self.model_repository.create(
  73. ModelDefinition(
  74. code=code,
  75. name=payload.name,
  76. provider_id=provider.id if provider is not None else None,
  77. provider_type=provider.provider_type if provider is not None else payload.provider_type,
  78. provider_base_url=provider.base_url if provider is not None else str(payload.provider_base_url),
  79. provider_api_key=provider.api_key if provider is not None else payload.provider_api_key,
  80. model_name=payload.model_name,
  81. status=payload.status,
  82. description=payload.description,
  83. capabilities_json=payload.capabilities_json,
  84. context_window=payload.context_window,
  85. max_output_tokens=payload.max_output_tokens,
  86. default_temperature=payload.default_temperature,
  87. timeout_seconds=payload.timeout_seconds,
  88. metadata_json=payload.metadata_json,
  89. )
  90. )
  91. def list_models(self) -> list[ModelDefinition]:
  92. return self.model_repository.list_all()
  93. def create_model_from_contract(self, payload: ModelCreateRequestDto) -> ModelDefinition:
  94. return self.create_model(
  95. ModelCreateRequest(
  96. name=payload.name,
  97. provider_id=payload.providerId,
  98. provider_type=payload.providerType,
  99. provider_base_url=payload.providerBaseUrl or "",
  100. provider_api_key=payload.providerApiKey,
  101. model_name=payload.modelName,
  102. description=payload.description,
  103. capabilities_json=payload.capabilities,
  104. context_window=payload.contextWindow,
  105. max_output_tokens=payload.maxOutputTokens,
  106. default_temperature=payload.defaultTemperature,
  107. timeout_seconds=payload.timeoutSeconds,
  108. metadata_json=payload.metadata))
  109. def update_model_from_contract(self, payload: ModelUpdateRequestDto) -> ModelDefinition | None:
  110. updates = payload.model_dump(exclude_unset=True)
  111. updates.pop("modelId", None)
  112. mapped_updates = {
  113. "name": updates.get("name"),
  114. "provider_id": updates.get("providerId"),
  115. "provider_type": updates.get("providerType"),
  116. "provider_base_url": updates.get("providerBaseUrl"),
  117. "provider_api_key": updates.get("providerApiKey"),
  118. "model_name": updates.get("modelName"),
  119. "description": updates.get("description"),
  120. "capabilities_json": updates.get("capabilities"),
  121. "context_window": updates.get("contextWindow"),
  122. "max_output_tokens": updates.get("maxOutputTokens"),
  123. "default_temperature": updates.get("defaultTemperature"),
  124. "timeout_seconds": updates.get("timeoutSeconds"),
  125. "metadata_json": updates.get("metadata"),
  126. }
  127. return self.update_model(
  128. model_id=payload.modelId,
  129. payload=ModelUpdateRequest(
  130. **{
  131. key: value
  132. for key, value in mapped_updates.items()
  133. if value is not None
  134. }))
  135. def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
  136. entity = self.model_repository.get_by_id(model_id)
  137. if entity is None:
  138. return None
  139. updates = payload.model_dump(exclude_unset=True)
  140. if "code" in updates and updates["code"] != entity.code:
  141. existing = self.model_repository.get_by_code(str(updates["code"]))
  142. if existing is not None and existing.id != entity.id:
  143. raise ValueError(f"model code already exists: {updates['code']}")
  144. provider = self._get_provider_or_raise(updates.get("provider_id"))
  145. if provider is not None:
  146. updates["provider_id"] = provider.id
  147. updates["provider_type"] = provider.provider_type
  148. updates["provider_base_url"] = provider.base_url
  149. updates["provider_api_key"] = provider.api_key
  150. for key, value in updates.items():
  151. if key == "provider_base_url" and value is not None:
  152. value = str(value)
  153. setattr(entity, key, value)
  154. return self.model_repository.update(entity)
  155. def update_model_status(
  156. self,
  157. model_id: str,
  158. payload: ModelStatusUpdateRequest,
  159. ) -> ModelDefinition | None:
  160. entity = self.model_repository.get_by_id(model_id)
  161. if entity is None:
  162. return None
  163. entity.status = payload.status
  164. return self.model_repository.update(entity)
  165. def delete_model(self, model_id: str) -> bool:
  166. entity = self.model_repository.get_by_id(model_id)
  167. if entity is None:
  168. return False
  169. self.model_repository.delete(entity)
  170. return True
  171. def delete_model_from_contract(self, payload: ModelDeleteRequestDto) -> bool:
  172. return self.delete_model(payload.modelId)
  173. def create_chat_completion(
  174. self,
  175. payload: ChatCompletionRequestContract) -> ChatCompletionResponseContract:
  176. configured_model = None
  177. if payload.model:
  178. configured_model = self.model_repository.get_active_for_request(payload.model)
  179. if configured_model is not None:
  180. configured_provider = self._resolve_model_provider(configured_model)
  181. resolved_payload = payload.model_copy(
  182. update={
  183. "model": configured_model.model_name,
  184. "temperature": payload.temperature
  185. if payload.temperature is not None
  186. else configured_model.default_temperature,
  187. "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
  188. }
  189. )
  190. return self.provider_client.create_chat_completion(
  191. resolved_payload,
  192. provider_type=configured_provider.provider_type,
  193. provider_base_url=configured_provider.provider_base_url,
  194. provider_api_key=configured_provider.provider_api_key,
  195. timeout_seconds=configured_model.timeout_seconds,
  196. )
  197. resolved_payload = payload.model_copy(
  198. update={"model": payload.model or self.settings.default_model}
  199. )
  200. return self.provider_client.create_chat_completion(
  201. resolved_payload,
  202. provider_type=self.settings.provider_type)
  203. def create_embedding(
  204. self,
  205. payload: EmbeddingRequestContract) -> EmbeddingResponseContract:
  206. configured_model = None
  207. if payload.model:
  208. configured_model = self.model_repository.get_active_for_request(payload.model)
  209. if configured_model is not None:
  210. configured_provider = self._resolve_model_provider(configured_model)
  211. resolved_payload = payload.model_copy(
  212. update={"model": configured_model.model_name}
  213. )
  214. return self.provider_client.create_embedding(
  215. resolved_payload,
  216. provider_type=configured_provider.provider_type,
  217. provider_base_url=configured_provider.provider_base_url,
  218. provider_api_key=configured_provider.provider_api_key,
  219. timeout_seconds=configured_model.timeout_seconds,
  220. )
  221. resolved_payload = payload.model_copy(
  222. update={"model": payload.model or self.settings.default_model}
  223. )
  224. return self.provider_client.create_embedding(
  225. resolved_payload,
  226. provider_type=self.settings.provider_type)
  227. def stream_chat_completion(
  228. self,
  229. payload: ChatCompletionRequestContract) -> Iterator[str]:
  230. configured_model = None
  231. if payload.model:
  232. configured_model = self.model_repository.get_active_for_request(payload.model)
  233. if configured_model is not None:
  234. configured_provider = self._resolve_model_provider(configured_model)
  235. resolved_payload = payload.model_copy(
  236. update={
  237. "model": configured_model.model_name,
  238. "temperature": payload.temperature
  239. if payload.temperature is not None
  240. else configured_model.default_temperature,
  241. "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
  242. }
  243. )
  244. yield from self.provider_client.stream_chat_completion(
  245. resolved_payload,
  246. provider_type=configured_provider.provider_type,
  247. provider_base_url=configured_provider.provider_base_url,
  248. provider_api_key=configured_provider.provider_api_key,
  249. timeout_seconds=configured_model.timeout_seconds,
  250. )
  251. return
  252. resolved_payload = payload.model_copy(
  253. update={"model": payload.model or self.settings.default_model}
  254. )
  255. yield from self.provider_client.stream_chat_completion(
  256. resolved_payload,
  257. provider_type=self.settings.provider_type)
  258. def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
  259. entity = self.model_repository.get_by_id(model_id)
  260. if entity is None:
  261. return None
  262. provider = self._resolve_model_provider(entity)
  263. messages = []
  264. if payload.system_prompt:
  265. messages.append({"role": "system", "content": payload.system_prompt})
  266. messages.append({"role": "user", "content": payload.prompt})
  267. response = self.provider_client.create_chat_completion(
  268. ChatCompletionRequestContract(
  269. model=entity.model_name,
  270. messages=messages,
  271. temperature=payload.temperature
  272. if payload.temperature is not None
  273. else entity.default_temperature,
  274. max_tokens=payload.max_tokens or entity.max_output_tokens,
  275. ),
  276. provider_type=provider.provider_type,
  277. provider_base_url=provider.provider_base_url,
  278. provider_api_key=provider.provider_api_key,
  279. timeout_seconds=entity.timeout_seconds,
  280. )
  281. from app.schemas.model import ModelResponse
  282. return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)
  283. def test_model_from_contract(self, payload: ModelTestRequestDto) -> ModelTestData | None:
  284. result = self.test_model(
  285. model_id=payload.modelId,
  286. payload=ModelTestRequest(
  287. prompt=payload.prompt,
  288. system_prompt=payload.systemPrompt,
  289. temperature=payload.temperature,
  290. max_tokens=payload.maxTokens))
  291. if result is None:
  292. return None
  293. entity = self.model_repository.get_by_id(payload.modelId)
  294. if entity is None:
  295. return None
  296. return ModelTestData(model=ModelDto.from_entity(entity), response=result.response)
  297. def list_providers(self) -> list[ModelProviderDefinition]:
  298. self._ensure_legacy_model_providers()
  299. return self.provider_repository.list_all()
  300. def create_provider(self, payload: ModelProviderCreateRequestDto) -> ModelProviderDefinition:
  301. provider = self.provider_repository.create(
  302. ModelProviderDefinition(
  303. name=payload.name,
  304. provider_type=payload.providerType,
  305. base_url=str(payload.baseUrl),
  306. api_key=payload.apiKey,
  307. models_json=[_to_snake_model_item(item) for item in payload.models],
  308. default_model=payload.defaultModel,
  309. extra_config_json=payload.extraConfig))
  310. self._refresh_and_sync_provider_models(provider, raise_on_empty=False)
  311. return provider
  312. def update_provider(
  313. self,
  314. payload: ModelProviderUpdateRequestDto) -> ModelProviderDefinition | None:
  315. entity = self.provider_repository.get_by_id(payload.providerId)
  316. if entity is None:
  317. return None
  318. updates = payload.model_dump(exclude_unset=True)
  319. updates.pop("providerId", None)
  320. for key, value in updates.items():
  321. if key == "baseUrl":
  322. entity.base_url = str(value) if value is not None else entity.base_url
  323. elif key == "apiKey":
  324. entity.api_key = value
  325. elif key == "defaultModel":
  326. entity.default_model = value
  327. elif key == "extraConfig":
  328. entity.extra_config_json = value
  329. elif key == "models":
  330. entity.models_json = [
  331. _to_snake_model_item(ModelItemDto(**item))
  332. if isinstance(item, dict)
  333. else _to_snake_model_item(item)
  334. for item in value or []
  335. ]
  336. elif key == "name":
  337. entity.name = value
  338. updated = self.provider_repository.update(entity)
  339. self._refresh_and_sync_provider_models(updated, raise_on_empty=False)
  340. return updated
  341. def delete_provider(self, payload: ModelProviderDeleteRequestDto) -> bool:
  342. entity = self.provider_repository.get_by_id(payload.providerId)
  343. if entity is None:
  344. return False
  345. self.provider_repository.delete(entity)
  346. return True
  347. def test_provider(self, payload: ModelProviderTestRequestDto) -> ModelProviderTestData | None:
  348. entity = self.provider_repository.get_by_id(payload.providerId)
  349. if entity is None:
  350. return None
  351. try:
  352. models = self.provider_client.list_models(
  353. provider_type=entity.provider_type,
  354. provider_base_url=entity.base_url,
  355. provider_api_key=entity.api_key)
  356. except ModelProviderClientError:
  357. models = list(entity.models_json or [])
  358. if not models:
  359. raise
  360. return ModelProviderTestData(
  361. success=True,
  362. message="Connection configuration is available.",
  363. latencyMs=0,
  364. modelList=[
  365. str(item.get("modelId") or item.get("model_id"))
  366. for item in models
  367. if item.get("modelId") or item.get("model_id")
  368. ])
  369. def discover_models(self, payload: DiscoverModelsRequestDto) -> DiscoverModelsData:
  370. provider_type = payload.providerType
  371. if payload.providerId:
  372. provider = self.provider_repository.get_by_id(payload.providerId)
  373. if provider is not None:
  374. discovered = self._refresh_and_sync_provider_models(
  375. provider,
  376. raise_on_empty=True)
  377. return DiscoverModelsData(
  378. providerType=provider.provider_type,
  379. models=discovered)
  380. if payload.baseUrl:
  381. discovered = [
  382. ModelItemDto(**item)
  383. for item in self.provider_client.list_models(
  384. provider_type=provider_type,
  385. provider_base_url=str(payload.baseUrl),
  386. provider_api_key=payload.apiKey)
  387. ]
  388. return DiscoverModelsData(
  389. providerType=provider_type or self.settings.provider_type,
  390. models=discovered)
  391. return DiscoverModelsData(
  392. providerType=provider_type or self.settings.provider_type,
  393. models=self._default_model_catalog(provider_type or self.settings.provider_type))
  394. def _default_model_catalog(self, provider_type: str) -> list[ModelItemDto]:
  395. catalogs = {
  396. "openai": [
  397. ModelItemDto(
  398. modelId="gpt-4.1-mini",
  399. displayName="GPT-4.1 Mini",
  400. modelType="chat",
  401. ownedBy="openai",
  402. contextWindow=1047576),
  403. ModelItemDto(
  404. modelId="text-embedding-3-small",
  405. displayName="Text Embedding 3 Small",
  406. modelType="embedding",
  407. ownedBy="openai"),
  408. ],
  409. "openai_compatible": [
  410. ModelItemDto(
  411. modelId="gpt-4.1-mini",
  412. displayName="OpenAI Compatible Chat",
  413. modelType="chat",
  414. contextWindow=128000),
  415. ModelItemDto(
  416. modelId="text-embedding-3-small",
  417. displayName="OpenAI Compatible Embedding",
  418. modelType="embedding"),
  419. ],
  420. "ollama": [
  421. ModelItemDto(
  422. modelId="llama3.1:8b",
  423. displayName="LLaMA 3.1 8B",
  424. modelType="chat",
  425. ownedBy="meta",
  426. contextWindow=131072),
  427. ModelItemDto(
  428. modelId="nomic-embed-text",
  429. displayName="Nomic Embed Text",
  430. modelType="embedding",
  431. ownedBy="nomic"),
  432. ],
  433. "deepseek": [
  434. ModelItemDto(
  435. modelId="deepseek-chat",
  436. displayName="DeepSeek Chat",
  437. modelType="chat",
  438. ownedBy="deepseek",
  439. contextWindow=64000),
  440. ModelItemDto(
  441. modelId="deepseek-reasoner",
  442. displayName="DeepSeek Reasoner",
  443. modelType="reasoning",
  444. ownedBy="deepseek",
  445. contextWindow=64000),
  446. ],
  447. }
  448. return catalogs.get(provider_type, [])
  449. def _build_model_code(self, name: str, model_name: str) -> str:
  450. base = "".join(
  451. char.lower() if char.isalnum() else "_"
  452. for char in f"{name}_{model_name}"
  453. ).strip("_") or "model"
  454. candidate = base[:64]
  455. suffix = 1
  456. while self.model_repository.get_by_code(candidate) is not None:
  457. suffix_text = f"_{suffix}"
  458. candidate = f"{base[:64 - len(suffix_text)]}{suffix_text}"
  459. suffix += 1
  460. return candidate
  461. def _get_provider_or_raise(self, provider_id: str | None) -> ModelProviderDefinition | None:
  462. if provider_id is None:
  463. return None
  464. provider = self.provider_repository.get_by_id(provider_id)
  465. if provider is None:
  466. raise ValueError(f"model provider not found: {provider_id}")
  467. return provider
  468. def _resolve_model_provider(self, model: ModelDefinition) -> "_ResolvedModelProvider":
  469. provider = self.provider_repository.get_by_id(model.provider_id) if model.provider_id else None
  470. if provider is not None:
  471. return _ResolvedModelProvider(
  472. provider_type=provider.provider_type,
  473. provider_base_url=provider.base_url,
  474. provider_api_key=provider.api_key)
  475. return _ResolvedModelProvider(
  476. provider_type=model.provider_type,
  477. provider_base_url=model.provider_base_url,
  478. provider_api_key=model.provider_api_key)
  479. def _ensure_legacy_model_providers(self) -> None:
  480. legacy_models = [
  481. model
  482. for model in self.model_repository.list_all()
  483. if model.provider_id is None and model.provider_base_url
  484. ]
  485. for model in legacy_models:
  486. provider = self.provider_repository.get_by_connection(
  487. provider_type=model.provider_type,
  488. base_url=model.provider_base_url)
  489. if provider is None:
  490. provider = self.provider_repository.create(
  491. ModelProviderDefinition(
  492. name=self._build_provider_name(model.provider_type, model.provider_base_url),
  493. provider_type=model.provider_type,
  494. base_url=model.provider_base_url,
  495. api_key=model.provider_api_key,
  496. models_json=[],
  497. default_model=model.model_name,
  498. extra_config_json={"source": "legacy_model_backfill"}))
  499. model.provider_id = provider.id
  500. self._append_provider_model(provider=provider, model=model)
  501. self.model_repository.update(model)
  502. self.provider_repository.update(provider)
  503. def _append_provider_model(
  504. self,
  505. *,
  506. provider: ModelProviderDefinition,
  507. model: ModelDefinition) -> None:
  508. existing_items = list(provider.models_json or [])
  509. if any(item.get("model_id") == model.model_name for item in existing_items):
  510. return
  511. existing_items.append(
  512. {
  513. "model_id": model.model_name,
  514. "display_name": model.name,
  515. "model_type": "chat",
  516. })
  517. provider.models_json = existing_items
  518. def _sync_provider_models(
  519. self,
  520. *,
  521. provider: ModelProviderDefinition,
  522. models: list[ModelItemDto]) -> None:
  523. for item in models:
  524. model_name = item.modelId.strip()
  525. if not model_name:
  526. continue
  527. existing = self.model_repository.get_by_provider_model(
  528. provider_id=provider.id,
  529. model_name=model_name)
  530. capabilities = self._capabilities_for_model_item(item, provider.provider_type)
  531. if existing is None:
  532. self.model_repository.create(
  533. ModelDefinition(
  534. code=self._build_model_code(item.displayName or model_name, model_name),
  535. name=item.displayName or model_name,
  536. provider_id=provider.id,
  537. provider_type=provider.provider_type,
  538. provider_base_url=provider.base_url,
  539. provider_api_key=provider.api_key,
  540. model_name=model_name,
  541. status="active",
  542. description=None,
  543. capabilities_json=capabilities,
  544. context_window=item.contextWindow,
  545. max_output_tokens=None,
  546. default_temperature=None,
  547. timeout_seconds=60.0,
  548. metadata_json={"source": "provider_discovery"},
  549. )
  550. )
  551. continue
  552. existing.name = item.displayName or existing.name
  553. existing.provider_type = provider.provider_type
  554. existing.provider_base_url = provider.base_url
  555. existing.provider_api_key = provider.api_key
  556. existing.capabilities_json = capabilities
  557. existing.context_window = item.contextWindow or existing.context_window
  558. existing.metadata_json = {
  559. **(existing.metadata_json or {}),
  560. "source": "provider_discovery",
  561. }
  562. self.model_repository.update(existing)
  563. def _refresh_and_sync_provider_models(
  564. self,
  565. provider: ModelProviderDefinition,
  566. *,
  567. raise_on_empty: bool) -> list[ModelItemDto]:
  568. try:
  569. discovered = [
  570. ModelItemDto(**item)
  571. for item in self.provider_client.list_models(
  572. provider_type=provider.provider_type,
  573. provider_base_url=provider.base_url,
  574. provider_api_key=provider.api_key)
  575. ]
  576. except ModelProviderClientError:
  577. discovered = ModelProviderDto.from_entity(provider).models
  578. if not discovered and provider.provider_type == "deepseek":
  579. discovered = self._default_model_catalog("deepseek")
  580. if not discovered and raise_on_empty:
  581. raise
  582. if not discovered:
  583. return []
  584. provider.models_json = [_to_snake_model_item(item) for item in discovered]
  585. if provider.default_model is None:
  586. provider.default_model = discovered[0].modelId
  587. self.provider_repository.update(provider)
  588. self._sync_provider_models(provider=provider, models=discovered)
  589. return discovered
  590. def _capabilities_for_model_item(
  591. self,
  592. item: ModelItemDto,
  593. provider_type: str) -> list[str]:
  594. model_type = item.modelType
  595. capabilities: set[str] = set()
  596. if model_type == "reasoning":
  597. capabilities.update(["chat", "reasoning"])
  598. elif model_type in {"embedding", "image", "audio", "video", "rerank", "moderation"}:
  599. capabilities.add(model_type)
  600. else:
  601. capabilities.add("chat")
  602. if provider_type in {"openai", "anthropic", "deepseek", "openai_compatible"} and "chat" in capabilities:
  603. capabilities.add("tools")
  604. return sorted(capabilities)
  605. def _build_provider_name(self, provider_type: str, base_url: str) -> str:
  606. label = provider_type.replace("_", " ").title()
  607. host = base_url.split("//")[-1].split("/")[0]
  608. return f"{label} - {host}" if host else label
  609. class _ResolvedModelProvider:
  610. def __init__(
  611. self,
  612. *,
  613. provider_type: str,
  614. provider_base_url: str,
  615. provider_api_key: str | None) -> None:
  616. self.provider_type = provider_type
  617. self.provider_base_url = provider_base_url
  618. self.provider_api_key = provider_api_key