services.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
import time
from dataclasses import dataclass

from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract

from app.bootstrap.settings import ModelGatewayServiceSettings
from app.db.models import ModelDefinition, ModelProviderDefinition
from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
from app.schemas.model import (
    DiscoverModelsData,
    DiscoverModelsRequestDto,
    ModelCreateRequest,
    ModelCreateRequestDto,
    ModelDeleteRequestDto,
    ModelDto,
    ModelItemDto,
    ModelProviderCreateRequestDto,
    ModelProviderDeleteRequestDto,
    ModelProviderDto,
    ModelProviderTestData,
    ModelProviderTestRequestDto,
    ModelProviderUpdateRequestDto,
    ModelStatusUpdateRequest,
    ModelTestData,
    ModelTestRequest,
    ModelTestRequestDto,
    ModelTestResponse,
    ModelUpdateRequest,
    ModelUpdateRequestDto,
    _to_snake_model_item,
)
  29. class ModelGatewayApplicationService:
  30. def __init__(
  31. self,
  32. *,
  33. model_repository: ModelDefinitionRepository,
  34. provider_repository: ModelProviderDefinitionRepository,
  35. provider_client: ModelProviderClient,
  36. settings: ModelGatewayServiceSettings) -> None:
  37. self.model_repository = model_repository
  38. self.provider_repository = provider_repository
  39. self.provider_client = provider_client
  40. self.settings = settings
  41. def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
  42. provider = self._get_provider_or_raise(payload.provider_id)
  43. if provider is not None:
  44. existing = self.model_repository.get_by_provider_model(
  45. provider_id=provider.id,
  46. model_name=payload.model_name)
  47. if existing is not None:
  48. existing.name = payload.name
  49. existing.provider_type = provider.provider_type
  50. existing.provider_base_url = provider.base_url
  51. existing.provider_api_key = provider.api_key
  52. existing.description = payload.description
  53. existing.capabilities_json = payload.capabilities_json
  54. existing.context_window = payload.context_window or existing.context_window
  55. existing.max_output_tokens = payload.max_output_tokens
  56. existing.default_temperature = payload.default_temperature
  57. existing.timeout_seconds = payload.timeout_seconds
  58. existing.metadata_json = {
  59. **(existing.metadata_json or {}),
  60. **payload.metadata_json,
  61. }
  62. return self.model_repository.update(existing)
  63. code = payload.code or self._build_model_code(payload.name, payload.model_name)
  64. if self.model_repository.get_by_code(code) is not None:
  65. raise ValueError(f"model code already exists: {code}")
  66. return self.model_repository.create(
  67. ModelDefinition(
  68. code=code,
  69. name=payload.name,
  70. provider_id=provider.id if provider is not None else None,
  71. provider_type=provider.provider_type if provider is not None else payload.provider_type,
  72. provider_base_url=provider.base_url if provider is not None else str(payload.provider_base_url),
  73. provider_api_key=provider.api_key if provider is not None else payload.provider_api_key,
  74. model_name=payload.model_name,
  75. status=payload.status,
  76. description=payload.description,
  77. capabilities_json=payload.capabilities_json,
  78. context_window=payload.context_window,
  79. max_output_tokens=payload.max_output_tokens,
  80. default_temperature=payload.default_temperature,
  81. timeout_seconds=payload.timeout_seconds,
  82. metadata_json=payload.metadata_json,
  83. )
  84. )
  85. def list_models(self) -> list[ModelDefinition]:
  86. return self.model_repository.list_all()
  87. def create_model_from_contract(self, payload: ModelCreateRequestDto) -> ModelDefinition:
  88. return self.create_model(
  89. ModelCreateRequest(
  90. name=payload.name,
  91. provider_id=payload.providerId,
  92. provider_type=payload.providerType,
  93. provider_base_url=payload.providerBaseUrl or "",
  94. provider_api_key=payload.providerApiKey,
  95. model_name=payload.modelName,
  96. description=payload.description,
  97. capabilities_json=payload.capabilities,
  98. context_window=payload.contextWindow,
  99. max_output_tokens=payload.maxOutputTokens,
  100. default_temperature=payload.defaultTemperature,
  101. timeout_seconds=payload.timeoutSeconds,
  102. metadata_json=payload.metadata))
  103. def update_model_from_contract(self, payload: ModelUpdateRequestDto) -> ModelDefinition | None:
  104. updates = payload.model_dump(exclude_unset=True)
  105. updates.pop("modelId", None)
  106. mapped_updates = {
  107. "name": updates.get("name"),
  108. "provider_id": updates.get("providerId"),
  109. "provider_type": updates.get("providerType"),
  110. "provider_base_url": updates.get("providerBaseUrl"),
  111. "provider_api_key": updates.get("providerApiKey"),
  112. "model_name": updates.get("modelName"),
  113. "description": updates.get("description"),
  114. "capabilities_json": updates.get("capabilities"),
  115. "context_window": updates.get("contextWindow"),
  116. "max_output_tokens": updates.get("maxOutputTokens"),
  117. "default_temperature": updates.get("defaultTemperature"),
  118. "timeout_seconds": updates.get("timeoutSeconds"),
  119. "metadata_json": updates.get("metadata"),
  120. }
  121. return self.update_model(
  122. model_id=payload.modelId,
  123. payload=ModelUpdateRequest(
  124. **{
  125. key: value
  126. for key, value in mapped_updates.items()
  127. if value is not None
  128. }))
  129. def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
  130. entity = self.model_repository.get_by_id(model_id)
  131. if entity is None:
  132. return None
  133. updates = payload.model_dump(exclude_unset=True)
  134. if "code" in updates and updates["code"] != entity.code:
  135. existing = self.model_repository.get_by_code(str(updates["code"]))
  136. if existing is not None and existing.id != entity.id:
  137. raise ValueError(f"model code already exists: {updates['code']}")
  138. provider = self._get_provider_or_raise(updates.get("provider_id"))
  139. if provider is not None:
  140. updates["provider_id"] = provider.id
  141. updates["provider_type"] = provider.provider_type
  142. updates["provider_base_url"] = provider.base_url
  143. updates["provider_api_key"] = provider.api_key
  144. for key, value in updates.items():
  145. if key == "provider_base_url" and value is not None:
  146. value = str(value)
  147. setattr(entity, key, value)
  148. return self.model_repository.update(entity)
  149. def update_model_status(
  150. self,
  151. model_id: str,
  152. payload: ModelStatusUpdateRequest,
  153. ) -> ModelDefinition | None:
  154. entity = self.model_repository.get_by_id(model_id)
  155. if entity is None:
  156. return None
  157. entity.status = payload.status
  158. return self.model_repository.update(entity)
  159. def delete_model(self, model_id: str) -> bool:
  160. entity = self.model_repository.get_by_id(model_id)
  161. if entity is None:
  162. return False
  163. self.model_repository.delete(entity)
  164. return True
  165. def delete_model_from_contract(self, payload: ModelDeleteRequestDto) -> bool:
  166. return self.delete_model(payload.modelId)
  167. def create_chat_completion(
  168. self,
  169. payload: ChatCompletionRequestContract) -> ChatCompletionResponseContract:
  170. configured_model = None
  171. if payload.model:
  172. configured_model = self.model_repository.get_active_for_request(payload.model)
  173. if configured_model is not None:
  174. configured_provider = self._resolve_model_provider(configured_model)
  175. resolved_payload = payload.model_copy(
  176. update={
  177. "model": configured_model.model_name,
  178. "temperature": payload.temperature
  179. if payload.temperature is not None
  180. else configured_model.default_temperature,
  181. "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
  182. }
  183. )
  184. return self.provider_client.create_chat_completion(
  185. resolved_payload,
  186. provider_type=configured_provider.provider_type,
  187. provider_base_url=configured_provider.provider_base_url,
  188. provider_api_key=configured_provider.provider_api_key,
  189. timeout_seconds=configured_model.timeout_seconds,
  190. )
  191. resolved_payload = payload.model_copy(
  192. update={"model": payload.model or self.settings.default_model}
  193. )
  194. return self.provider_client.create_chat_completion(
  195. resolved_payload,
  196. provider_type=self.settings.provider_type)
  197. def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
  198. entity = self.model_repository.get_by_id(model_id)
  199. if entity is None:
  200. return None
  201. provider = self._resolve_model_provider(entity)
  202. messages = []
  203. if payload.system_prompt:
  204. messages.append({"role": "system", "content": payload.system_prompt})
  205. messages.append({"role": "user", "content": payload.prompt})
  206. response = self.provider_client.create_chat_completion(
  207. ChatCompletionRequestContract(
  208. model=entity.model_name,
  209. messages=messages,
  210. temperature=payload.temperature
  211. if payload.temperature is not None
  212. else entity.default_temperature,
  213. max_tokens=payload.max_tokens or entity.max_output_tokens,
  214. ),
  215. provider_type=provider.provider_type,
  216. provider_base_url=provider.provider_base_url,
  217. provider_api_key=provider.provider_api_key,
  218. timeout_seconds=entity.timeout_seconds,
  219. )
  220. from app.schemas.model import ModelResponse
  221. return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)
  222. def test_model_from_contract(self, payload: ModelTestRequestDto) -> ModelTestData | None:
  223. result = self.test_model(
  224. model_id=payload.modelId,
  225. payload=ModelTestRequest(
  226. prompt=payload.prompt,
  227. system_prompt=payload.systemPrompt,
  228. temperature=payload.temperature,
  229. max_tokens=payload.maxTokens))
  230. if result is None:
  231. return None
  232. entity = self.model_repository.get_by_id(payload.modelId)
  233. if entity is None:
  234. return None
  235. return ModelTestData(model=ModelDto.from_entity(entity), response=result.response)
  236. def list_providers(self) -> list[ModelProviderDefinition]:
  237. self._ensure_legacy_model_providers()
  238. return self.provider_repository.list_all()
  239. def create_provider(self, payload: ModelProviderCreateRequestDto) -> ModelProviderDefinition:
  240. provider = self.provider_repository.create(
  241. ModelProviderDefinition(
  242. name=payload.name,
  243. provider_type=payload.providerType,
  244. base_url=str(payload.baseUrl),
  245. api_key=payload.apiKey,
  246. models_json=[_to_snake_model_item(item) for item in payload.models],
  247. default_model=payload.defaultModel,
  248. extra_config_json=payload.extraConfig))
  249. self._refresh_and_sync_provider_models(provider, raise_on_empty=False)
  250. return provider
  251. def update_provider(
  252. self,
  253. payload: ModelProviderUpdateRequestDto) -> ModelProviderDefinition | None:
  254. entity = self.provider_repository.get_by_id(payload.providerId)
  255. if entity is None:
  256. return None
  257. updates = payload.model_dump(exclude_unset=True)
  258. updates.pop("providerId", None)
  259. for key, value in updates.items():
  260. if key == "baseUrl":
  261. entity.base_url = str(value) if value is not None else entity.base_url
  262. elif key == "apiKey":
  263. entity.api_key = value
  264. elif key == "defaultModel":
  265. entity.default_model = value
  266. elif key == "extraConfig":
  267. entity.extra_config_json = value
  268. elif key == "models":
  269. entity.models_json = [
  270. _to_snake_model_item(ModelItemDto(**item))
  271. if isinstance(item, dict)
  272. else _to_snake_model_item(item)
  273. for item in value or []
  274. ]
  275. elif key == "name":
  276. entity.name = value
  277. updated = self.provider_repository.update(entity)
  278. self._refresh_and_sync_provider_models(updated, raise_on_empty=False)
  279. return updated
  280. def delete_provider(self, payload: ModelProviderDeleteRequestDto) -> bool:
  281. entity = self.provider_repository.get_by_id(payload.providerId)
  282. if entity is None:
  283. return False
  284. self.provider_repository.delete(entity)
  285. return True
  286. def test_provider(self, payload: ModelProviderTestRequestDto) -> ModelProviderTestData | None:
  287. entity = self.provider_repository.get_by_id(payload.providerId)
  288. if entity is None:
  289. return None
  290. try:
  291. models = self.provider_client.list_models(
  292. provider_type=entity.provider_type,
  293. provider_base_url=entity.base_url,
  294. provider_api_key=entity.api_key)
  295. except ModelProviderClientError:
  296. models = list(entity.models_json or [])
  297. if not models:
  298. raise
  299. return ModelProviderTestData(
  300. success=True,
  301. message="Connection configuration is available.",
  302. latencyMs=0,
  303. modelList=[
  304. str(item.get("modelId") or item.get("model_id"))
  305. for item in models
  306. if item.get("modelId") or item.get("model_id")
  307. ])
  308. def discover_models(self, payload: DiscoverModelsRequestDto) -> DiscoverModelsData:
  309. provider_type = payload.providerType
  310. if payload.providerId:
  311. provider = self.provider_repository.get_by_id(payload.providerId)
  312. if provider is not None:
  313. discovered = self._refresh_and_sync_provider_models(
  314. provider,
  315. raise_on_empty=True)
  316. return DiscoverModelsData(
  317. providerType=provider.provider_type,
  318. models=discovered)
  319. if payload.baseUrl:
  320. discovered = [
  321. ModelItemDto(**item)
  322. for item in self.provider_client.list_models(
  323. provider_type=provider_type,
  324. provider_base_url=str(payload.baseUrl),
  325. provider_api_key=payload.apiKey)
  326. ]
  327. return DiscoverModelsData(
  328. providerType=provider_type or self.settings.provider_type,
  329. models=discovered)
  330. return DiscoverModelsData(
  331. providerType=provider_type or self.settings.provider_type,
  332. models=self._default_model_catalog(provider_type or self.settings.provider_type))
  333. def _default_model_catalog(self, provider_type: str) -> list[ModelItemDto]:
  334. catalogs = {
  335. "openai": [
  336. ModelItemDto(
  337. modelId="gpt-4.1-mini",
  338. displayName="GPT-4.1 Mini",
  339. modelType="chat",
  340. ownedBy="openai",
  341. contextWindow=1047576),
  342. ModelItemDto(
  343. modelId="text-embedding-3-small",
  344. displayName="Text Embedding 3 Small",
  345. modelType="embedding",
  346. ownedBy="openai"),
  347. ],
  348. "openai_compatible": [
  349. ModelItemDto(
  350. modelId="gpt-4.1-mini",
  351. displayName="OpenAI Compatible Chat",
  352. modelType="chat",
  353. contextWindow=128000),
  354. ModelItemDto(
  355. modelId="text-embedding-3-small",
  356. displayName="OpenAI Compatible Embedding",
  357. modelType="embedding"),
  358. ],
  359. "ollama": [
  360. ModelItemDto(
  361. modelId="llama3.1:8b",
  362. displayName="LLaMA 3.1 8B",
  363. modelType="chat",
  364. ownedBy="meta",
  365. contextWindow=131072),
  366. ModelItemDto(
  367. modelId="nomic-embed-text",
  368. displayName="Nomic Embed Text",
  369. modelType="embedding",
  370. ownedBy="nomic"),
  371. ],
  372. "deepseek": [
  373. ModelItemDto(
  374. modelId="deepseek-chat",
  375. displayName="DeepSeek Chat",
  376. modelType="chat",
  377. ownedBy="deepseek",
  378. contextWindow=64000),
  379. ModelItemDto(
  380. modelId="deepseek-reasoner",
  381. displayName="DeepSeek Reasoner",
  382. modelType="reasoning",
  383. ownedBy="deepseek",
  384. contextWindow=64000),
  385. ],
  386. }
  387. return catalogs.get(provider_type, [])
  388. def _build_model_code(self, name: str, model_name: str) -> str:
  389. base = "".join(
  390. char.lower() if char.isalnum() else "_"
  391. for char in f"{name}_{model_name}"
  392. ).strip("_") or "model"
  393. candidate = base[:64]
  394. suffix = 1
  395. while self.model_repository.get_by_code(candidate) is not None:
  396. suffix_text = f"_{suffix}"
  397. candidate = f"{base[:64 - len(suffix_text)]}{suffix_text}"
  398. suffix += 1
  399. return candidate
  400. def _get_provider_or_raise(self, provider_id: str | None) -> ModelProviderDefinition | None:
  401. if provider_id is None:
  402. return None
  403. provider = self.provider_repository.get_by_id(provider_id)
  404. if provider is None:
  405. raise ValueError(f"model provider not found: {provider_id}")
  406. return provider
  407. def _resolve_model_provider(self, model: ModelDefinition) -> "_ResolvedModelProvider":
  408. provider = self.provider_repository.get_by_id(model.provider_id) if model.provider_id else None
  409. if provider is not None:
  410. return _ResolvedModelProvider(
  411. provider_type=provider.provider_type,
  412. provider_base_url=provider.base_url,
  413. provider_api_key=provider.api_key)
  414. return _ResolvedModelProvider(
  415. provider_type=model.provider_type,
  416. provider_base_url=model.provider_base_url,
  417. provider_api_key=model.provider_api_key)
  418. def _ensure_legacy_model_providers(self) -> None:
  419. legacy_models = [
  420. model
  421. for model in self.model_repository.list_all()
  422. if model.provider_id is None and model.provider_base_url
  423. ]
  424. for model in legacy_models:
  425. provider = self.provider_repository.get_by_connection(
  426. provider_type=model.provider_type,
  427. base_url=model.provider_base_url)
  428. if provider is None:
  429. provider = self.provider_repository.create(
  430. ModelProviderDefinition(
  431. name=self._build_provider_name(model.provider_type, model.provider_base_url),
  432. provider_type=model.provider_type,
  433. base_url=model.provider_base_url,
  434. api_key=model.provider_api_key,
  435. models_json=[],
  436. default_model=model.model_name,
  437. extra_config_json={"source": "legacy_model_backfill"}))
  438. model.provider_id = provider.id
  439. self._append_provider_model(provider=provider, model=model)
  440. self.model_repository.update(model)
  441. self.provider_repository.update(provider)
  442. def _append_provider_model(
  443. self,
  444. *,
  445. provider: ModelProviderDefinition,
  446. model: ModelDefinition) -> None:
  447. existing_items = list(provider.models_json or [])
  448. if any(item.get("model_id") == model.model_name for item in existing_items):
  449. return
  450. existing_items.append(
  451. {
  452. "model_id": model.model_name,
  453. "display_name": model.name,
  454. "model_type": "chat",
  455. })
  456. provider.models_json = existing_items
  457. def _sync_provider_models(
  458. self,
  459. *,
  460. provider: ModelProviderDefinition,
  461. models: list[ModelItemDto]) -> None:
  462. for item in models:
  463. model_name = item.modelId.strip()
  464. if not model_name:
  465. continue
  466. existing = self.model_repository.get_by_provider_model(
  467. provider_id=provider.id,
  468. model_name=model_name)
  469. capabilities = self._capabilities_for_model_item(item, provider.provider_type)
  470. if existing is None:
  471. self.model_repository.create(
  472. ModelDefinition(
  473. code=self._build_model_code(item.displayName or model_name, model_name),
  474. name=item.displayName or model_name,
  475. provider_id=provider.id,
  476. provider_type=provider.provider_type,
  477. provider_base_url=provider.base_url,
  478. provider_api_key=provider.api_key,
  479. model_name=model_name,
  480. status="active",
  481. description=None,
  482. capabilities_json=capabilities,
  483. context_window=item.contextWindow,
  484. max_output_tokens=None,
  485. default_temperature=None,
  486. timeout_seconds=60.0,
  487. metadata_json={"source": "provider_discovery"},
  488. )
  489. )
  490. continue
  491. existing.name = item.displayName or existing.name
  492. existing.provider_type = provider.provider_type
  493. existing.provider_base_url = provider.base_url
  494. existing.provider_api_key = provider.api_key
  495. existing.capabilities_json = capabilities
  496. existing.context_window = item.contextWindow or existing.context_window
  497. existing.metadata_json = {
  498. **(existing.metadata_json or {}),
  499. "source": "provider_discovery",
  500. }
  501. self.model_repository.update(existing)
  502. def _refresh_and_sync_provider_models(
  503. self,
  504. provider: ModelProviderDefinition,
  505. *,
  506. raise_on_empty: bool) -> list[ModelItemDto]:
  507. try:
  508. discovered = [
  509. ModelItemDto(**item)
  510. for item in self.provider_client.list_models(
  511. provider_type=provider.provider_type,
  512. provider_base_url=provider.base_url,
  513. provider_api_key=provider.api_key)
  514. ]
  515. except ModelProviderClientError:
  516. discovered = ModelProviderDto.from_entity(provider).models
  517. if not discovered and provider.provider_type == "deepseek":
  518. discovered = self._default_model_catalog("deepseek")
  519. if not discovered and raise_on_empty:
  520. raise
  521. if not discovered:
  522. return []
  523. provider.models_json = [_to_snake_model_item(item) for item in discovered]
  524. if provider.default_model is None:
  525. provider.default_model = discovered[0].modelId
  526. self.provider_repository.update(provider)
  527. self._sync_provider_models(provider=provider, models=discovered)
  528. return discovered
  529. def _capabilities_for_model_item(
  530. self,
  531. item: ModelItemDto,
  532. provider_type: str) -> list[str]:
  533. model_type = item.modelType
  534. capabilities: set[str] = set()
  535. if model_type == "reasoning":
  536. capabilities.update(["chat", "reasoning"])
  537. elif model_type in {"embedding", "image", "audio", "video", "rerank", "moderation"}:
  538. capabilities.add(model_type)
  539. else:
  540. capabilities.add("chat")
  541. if provider_type in {"openai", "anthropic", "deepseek", "openai_compatible"} and "chat" in capabilities:
  542. capabilities.add("tools")
  543. return sorted(capabilities)
  544. def _build_provider_name(self, provider_type: str, base_url: str) -> str:
  545. label = provider_type.replace("_", " ").title()
  546. host = base_url.split("//")[-1].split("/")[0]
  547. return f"{label} - {host}" if host else label
  548. class _ResolvedModelProvider:
  549. def __init__(
  550. self,
  551. *,
  552. provider_type: str,
  553. provider_base_url: str,
  554. provider_api_key: str | None) -> None:
  555. self.provider_type = provider_type
  556. self.provider_base_url = provider_base_url
  557. self.provider_api_key = provider_api_key