services.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  2. from app.bootstrap.settings import ModelGatewayServiceSettings
  3. from app.db.models import ModelDefinition, ModelProviderDefinition
  4. from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
  5. from app.infrastructure.provider import ModelProviderClient
  6. from app.schemas.model import (
  7. DiscoverModelsData,
  8. DiscoverModelsRequestDto,
  9. ModelCreateRequest,
  10. ModelCreateRequestDto,
  11. ModelDeleteRequestDto,
  12. ModelDto,
  13. ModelItemDto,
  14. ModelProviderCreateRequestDto,
  15. ModelProviderDeleteRequestDto,
  16. ModelProviderDto,
  17. ModelProviderTestData,
  18. ModelProviderTestRequestDto,
  19. ModelProviderUpdateRequestDto,
  20. ModelStatusUpdateRequest,
  21. ModelTestData,
  22. ModelTestRequest,
  23. ModelTestRequestDto,
  24. ModelTestResponse,
  25. ModelUpdateRequest,
  26. ModelUpdateRequestDto,
  27. _to_snake_model_item,
  28. )
  class ModelGatewayApplicationService:
      """Application service for model definitions, providers and chat completions.

      Coordinates the model repository, the provider repository and the
      provider client using the gateway settings supplied at construction.
      """

      def __init__(
          self,
          *,
          model_repository: ModelDefinitionRepository,
          provider_repository: ModelProviderDefinitionRepository,
          provider_client: ModelProviderClient,
          settings: ModelGatewayServiceSettings,
      ) -> None:
          """Store the collaborators; every argument is keyword-only."""
          self.model_repository = model_repository
          self.provider_repository = provider_repository
          self.provider_client = provider_client
          self.settings = settings

      # ------------------------------------------------------------------
      # Models
      # ------------------------------------------------------------------

      def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
          """Create and persist a model definition.

          When ``payload.provider_id`` is set, the connection fields are copied
          from that provider; otherwise they are taken from the payload itself.

          Raises:
              ValueError: if the referenced provider does not exist or the
                  resulting model code is already taken.
          """
          provider = self._get_provider_or_raise(payload.provider_id)
          code = payload.code or self._build_model_code(payload.name, payload.model_name)
          if self.model_repository.get_by_code(code) is not None:
              raise ValueError(f"model code already exists: {code}")

          if provider is not None:
              provider_id = provider.id
              provider_type = provider.provider_type
              provider_base_url = provider.base_url
              provider_api_key = provider.api_key
          else:
              provider_id = None
              provider_type = payload.provider_type
              # str(None) would persist the literal text "None" as a URL; store
              # an empty string instead, matching the default used by
              # create_model_from_contract.
              raw_base_url = payload.provider_base_url
              provider_base_url = str(raw_base_url) if raw_base_url is not None else ""
              provider_api_key = payload.provider_api_key

          return self.model_repository.create(
              ModelDefinition(
                  code=code,
                  name=payload.name,
                  provider_id=provider_id,
                  provider_type=provider_type,
                  provider_base_url=provider_base_url,
                  provider_api_key=provider_api_key,
                  model_name=payload.model_name,
                  status=payload.status,
                  description=payload.description,
                  capabilities_json=payload.capabilities_json,
                  context_window=payload.context_window,
                  max_output_tokens=payload.max_output_tokens,
                  default_temperature=payload.default_temperature,
                  timeout_seconds=payload.timeout_seconds,
                  metadata_json=payload.metadata_json,
              )
          )

      def list_models(self) -> list[ModelDefinition]:
          """Return every model definition known to the repository."""
          return self.model_repository.list_all()

      def create_model_from_contract(self, payload: ModelCreateRequestDto) -> ModelDefinition:
          """Translate the camelCase DTO into a ``ModelCreateRequest`` and create it."""
          return self.create_model(
              ModelCreateRequest(
                  name=payload.name,
                  provider_id=payload.providerId,
                  provider_type=payload.providerType,
                  provider_base_url=payload.providerBaseUrl or "",
                  provider_api_key=payload.providerApiKey,
                  model_name=payload.modelName,
                  description=payload.description,
                  capabilities_json=payload.capabilities,
                  context_window=payload.contextWindow,
                  max_output_tokens=payload.maxOutputTokens,
                  default_temperature=payload.defaultTemperature,
                  timeout_seconds=payload.timeoutSeconds,
                  metadata_json=payload.metadata,
              )
          )

      def update_model_from_contract(self, payload: ModelUpdateRequestDto) -> ModelDefinition | None:
          """Translate the camelCase update DTO and apply it via ``update_model``.

          Returns ``None`` when the model does not exist.
          """
          updates = payload.model_dump(exclude_unset=True)
          updates.pop("modelId", None)
          # DTO field name -> internal request field name.
          field_map = {
              "name": "name",
              "providerId": "provider_id",
              "providerType": "provider_type",
              "providerBaseUrl": "provider_base_url",
              "providerApiKey": "provider_api_key",
              "modelName": "model_name",
              "description": "description",
              "capabilities": "capabilities_json",
              "contextWindow": "context_window",
              "maxOutputTokens": "max_output_tokens",
              "defaultTemperature": "default_temperature",
              "timeoutSeconds": "timeout_seconds",
              "metadata": "metadata_json",
          }
          # NOTE(review): values that are None are dropped here, so a field
          # cannot be cleared (set to null) through this entry point.
          mapped_updates = {
              target: updates[source]
              for source, target in field_map.items()
              if updates.get(source) is not None
          }
          return self.update_model(
              model_id=payload.modelId,
              payload=ModelUpdateRequest(**mapped_updates),
          )

      def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
          """Apply the explicitly-set fields of ``payload`` to an existing model.

          Returns ``None`` when no model has ``model_id``.

          Raises:
              ValueError: if the new code belongs to another model or the
                  referenced provider does not exist.
          """
          entity = self.model_repository.get_by_id(model_id)
          if entity is None:
              return None

          updates = payload.model_dump(exclude_unset=True)
          if "code" in updates and updates["code"] != entity.code:
              existing = self.model_repository.get_by_code(str(updates["code"]))
              if existing is not None and existing.id != entity.id:
                  raise ValueError(f"model code already exists: {updates['code']}")

          provider = self._get_provider_or_raise(updates.get("provider_id"))
          if provider is not None:
              # A linked provider overrides any connection fields in the payload.
              updates["provider_id"] = provider.id
              updates["provider_type"] = provider.provider_type
              updates["provider_base_url"] = provider.base_url
              updates["provider_api_key"] = provider.api_key

          for key, value in updates.items():
              if key == "provider_base_url" and value is not None:
                  value = str(value)
              setattr(entity, key, value)
          return self.model_repository.update(entity)

      def update_model_status(
          self,
          model_id: str,
          payload: ModelStatusUpdateRequest,
      ) -> ModelDefinition | None:
          """Set the model's status; return ``None`` when the model is missing."""
          entity = self.model_repository.get_by_id(model_id)
          if entity is None:
              return None
          entity.status = payload.status
          return self.model_repository.update(entity)

      def delete_model(self, model_id: str) -> bool:
          """Delete the model; return ``False`` when it does not exist."""
          entity = self.model_repository.get_by_id(model_id)
          if entity is None:
              return False
          self.model_repository.delete(entity)
          return True

      def delete_model_from_contract(self, payload: ModelDeleteRequestDto) -> bool:
          """DTO wrapper around ``delete_model``."""
          return self.delete_model(payload.modelId)

      # ------------------------------------------------------------------
      # Chat completion
      # ------------------------------------------------------------------

      def create_chat_completion(
          self,
          payload: ChatCompletionRequestContract,
      ) -> ChatCompletionResponseContract:
          """Route a chat completion request to a provider.

          If ``payload.model`` resolves to an active configured model, the
          request is rewritten with that model's upstream name and defaults and
          sent to the model's provider. Otherwise it is sent with the settings'
          provider type, using ``settings.default_model`` when no model was
          requested.
          """
          configured_model = None
          if payload.model:
              configured_model = self.model_repository.get_active_for_request(payload.model)

          if configured_model is not None:
              configured_provider = self._resolve_model_provider(configured_model)
              # An explicit temperature (including 0) wins over the model default.
              if payload.temperature is not None:
                  temperature = payload.temperature
              else:
                  temperature = configured_model.default_temperature
              resolved_payload = payload.model_copy(
                  update={
                      "model": configured_model.model_name,
                      "temperature": temperature,
                      "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
                  }
              )
              return self.provider_client.create_chat_completion(
                  resolved_payload,
                  provider_type=configured_provider.provider_type,
                  provider_base_url=configured_provider.provider_base_url,
                  provider_api_key=configured_provider.provider_api_key,
                  timeout_seconds=configured_model.timeout_seconds,
              )

          resolved_payload = payload.model_copy(
              update={"model": payload.model or self.settings.default_model}
          )
          return self.provider_client.create_chat_completion(
              resolved_payload,
              provider_type=self.settings.provider_type,
          )

      def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
          """Send a test prompt to the model's provider.

          Returns ``None`` when no model has ``model_id``.
          """
          entity = self.model_repository.get_by_id(model_id)
          if entity is None:
              return None

          provider = self._resolve_model_provider(entity)
          messages = []
          if payload.system_prompt:
              messages.append({"role": "system", "content": payload.system_prompt})
          messages.append({"role": "user", "content": payload.prompt})

          if payload.temperature is not None:
              temperature = payload.temperature
          else:
              temperature = entity.default_temperature
          response = self.provider_client.create_chat_completion(
              ChatCompletionRequestContract(
                  model=entity.model_name,
                  messages=messages,
                  temperature=temperature,
                  max_tokens=payload.max_tokens or entity.max_output_tokens,
              ),
              provider_type=provider.provider_type,
              provider_base_url=provider.provider_base_url,
              provider_api_key=provider.provider_api_key,
              timeout_seconds=entity.timeout_seconds,
          )
          # NOTE(review): kept as a local import as in the original; it could
          # probably join the module-level import from app.schemas.model.
          from app.schemas.model import ModelResponse

          return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)

      def test_model_from_contract(self, payload: ModelTestRequestDto) -> ModelTestData | None:
          """DTO wrapper around ``test_model``; ``None`` when the model is missing."""
          result = self.test_model(
              model_id=payload.modelId,
              payload=ModelTestRequest(
                  prompt=payload.prompt,
                  system_prompt=payload.systemPrompt,
                  temperature=payload.temperature,
                  max_tokens=payload.maxTokens,
              ),
          )
          if result is None:
              return None
          # The entity is fetched again to build the DTO representation.
          entity = self.model_repository.get_by_id(payload.modelId)
          if entity is None:
              return None
          return ModelTestData(model=ModelDto.from_entity(entity), response=result.response)

      # ------------------------------------------------------------------
      # Providers
      # ------------------------------------------------------------------

      def list_providers(self) -> list[ModelProviderDefinition]:
          """List all providers, first backfilling providers for legacy models."""
          self._ensure_legacy_model_providers()
          return self.provider_repository.list_all()

      def create_provider(self, payload: ModelProviderCreateRequestDto) -> ModelProviderDefinition:
          """Create and persist a provider definition from the DTO."""
          return self.provider_repository.create(
              ModelProviderDefinition(
                  name=payload.name,
                  provider_type=payload.providerType,
                  base_url=str(payload.baseUrl),
                  api_key=payload.apiKey,
                  models_json=[_to_snake_model_item(item) for item in payload.models],
                  default_model=payload.defaultModel,
                  extra_config_json=payload.extraConfig,
              )
          )

      def update_provider(
          self,
          payload: ModelProviderUpdateRequestDto,
      ) -> ModelProviderDefinition | None:
          """Apply the explicitly-set DTO fields to a provider.

          Returns ``None`` when the provider does not exist. Only ``baseUrl``,
          ``apiKey``, ``defaultModel``, ``extraConfig``, ``models`` and ``name``
          are applied; any other key in the DTO is ignored.
          """
          entity = self.provider_repository.get_by_id(payload.providerId)
          if entity is None:
              return None

          updates = payload.model_dump(exclude_unset=True)
          updates.pop("providerId", None)
          for key, value in updates.items():
              if key == "baseUrl":
                  # A null base URL leaves the stored value untouched.
                  entity.base_url = str(value) if value is not None else entity.base_url
              elif key == "apiKey":
                  entity.api_key = value
              elif key == "defaultModel":
                  entity.default_model = value
              elif key == "extraConfig":
                  entity.extra_config_json = value
              elif key == "models":
                  # model_dump yields dicts; tolerate already-built items too.
                  entity.models_json = [
                      _to_snake_model_item(ModelItemDto(**item))
                      if isinstance(item, dict)
                      else _to_snake_model_item(item)
                      for item in value or []
                  ]
              elif key == "name":
                  entity.name = value
          return self.provider_repository.update(entity)

      def delete_provider(self, payload: ModelProviderDeleteRequestDto) -> bool:
          """Delete the provider; return ``False`` when it does not exist."""
          entity = self.provider_repository.get_by_id(payload.providerId)
          if entity is None:
              return False
          self.provider_repository.delete(entity)
          return True

      def test_provider(self, payload: ModelProviderTestRequestDto) -> ModelProviderTestData | None:
          """Report the stored model ids of a provider.

          Returns ``None`` when the provider does not exist.

          NOTE(review): no request is made to the provider here; ``success`` is
          always ``True`` and ``latencyMs`` is always ``0``.
          """
          entity = self.provider_repository.get_by_id(payload.providerId)
          if entity is None:
              return None
          return ModelProviderTestData(
              success=True,
              message="Connection configuration is available.",
              latencyMs=0,
              modelList=[
                  str(item.get("model_id") or item.get("modelId"))
                  for item in entity.models_json or []
                  if item.get("model_id") or item.get("modelId")
              ],
          )

      def discover_models(self, payload: DiscoverModelsRequestDto) -> DiscoverModelsData:
          """Return the model catalog for a stored provider or a provider type.

          A known ``providerId`` returns that provider's stored models;
          otherwise the built-in catalog for the requested (or default)
          provider type is returned.
          """
          provider_type = payload.providerType
          if payload.providerId:
              provider = self.provider_repository.get_by_id(payload.providerId)
              if provider is not None:
                  return DiscoverModelsData(
                      providerType=provider.provider_type,
                      models=ModelProviderDto.from_entity(provider).models,
                  )
          return DiscoverModelsData(
              providerType=provider_type or self.settings.provider_type,
              models=self._default_model_catalog(provider_type or self.settings.provider_type),
          )

      # ------------------------------------------------------------------
      # Internal helpers
      # ------------------------------------------------------------------

      def _default_model_catalog(self, provider_type: str) -> list[ModelItemDto]:
          """Return the built-in catalog for ``provider_type`` (empty if unknown)."""
          catalogs = {
              "openai": [
                  ModelItemDto(
                      modelId="gpt-4.1-mini",
                      displayName="GPT-4.1 Mini",
                      modelType="chat",
                      ownedBy="openai",
                      contextWindow=1047576,
                  ),
                  ModelItemDto(
                      modelId="text-embedding-3-small",
                      displayName="Text Embedding 3 Small",
                      modelType="embedding",
                      ownedBy="openai",
                  ),
              ],
              "openai_compatible": [
                  ModelItemDto(
                      modelId="gpt-4.1-mini",
                      displayName="OpenAI Compatible Chat",
                      modelType="chat",
                      contextWindow=128000,
                  ),
                  ModelItemDto(
                      modelId="text-embedding-3-small",
                      displayName="OpenAI Compatible Embedding",
                      modelType="embedding",
                  ),
              ],
              "ollama": [
                  ModelItemDto(
                      modelId="llama3.1:8b",
                      displayName="LLaMA 3.1 8B",
                      modelType="chat",
                      ownedBy="meta",
                      contextWindow=131072,
                  ),
                  ModelItemDto(
                      modelId="nomic-embed-text",
                      displayName="Nomic Embed Text",
                      modelType="embedding",
                      ownedBy="nomic",
                  ),
              ],
          }
          return catalogs.get(provider_type, [])

      def _build_model_code(self, name: str, model_name: str) -> str:
          """Derive a unique code of at most 64 characters from the two names.

          Non-alphanumeric characters become ``_``; when the code is taken, a
          numeric suffix (``_1``, ``_2``, ...) is appended until it is free.
          """
          base = "".join(
              char.lower() if char.isalnum() else "_"
              for char in f"{name}_{model_name}"
          ).strip("_") or "model"
          candidate = base[:64]
          suffix = 1
          while self.model_repository.get_by_code(candidate) is not None:
              suffix_text = f"_{suffix}"
              # Trim the base so that base + suffix still fits in 64 characters.
              candidate = f"{base[:64 - len(suffix_text)]}{suffix_text}"
              suffix += 1
          return candidate

      def _get_provider_or_raise(self, provider_id: str | None) -> ModelProviderDefinition | None:
          """Look up a provider; ``None`` in gives ``None`` out.

          Raises:
              ValueError: if ``provider_id`` is given but no provider matches.
          """
          if provider_id is None:
              return None
          provider = self.provider_repository.get_by_id(provider_id)
          if provider is None:
              raise ValueError(f"model provider not found: {provider_id}")
          return provider

      def _resolve_model_provider(self, model: ModelDefinition) -> "_ResolvedModelProvider":
          """Return connection details from the linked provider, else from the model."""
          provider = self.provider_repository.get_by_id(model.provider_id) if model.provider_id else None
          if provider is not None:
              return _ResolvedModelProvider(
                  provider_type=provider.provider_type,
                  provider_base_url=provider.base_url,
                  provider_api_key=provider.api_key,
              )
          return _ResolvedModelProvider(
              provider_type=model.provider_type,
              provider_base_url=model.provider_base_url,
              provider_api_key=model.provider_api_key,
          )

      def _ensure_legacy_model_providers(self) -> None:
          """Link models that carry inline connection details to a provider.

          For each model with no ``provider_id`` but a base URL, reuse the
          provider with the same type and base URL or create one, then record
          the model in that provider's catalog and persist both.
          """
          legacy_models = [
              model
              for model in self.model_repository.list_all()
              if model.provider_id is None and model.provider_base_url
          ]
          for model in legacy_models:
              provider = self.provider_repository.get_by_connection(
                  provider_type=model.provider_type,
                  base_url=model.provider_base_url,
              )
              if provider is None:
                  provider = self.provider_repository.create(
                      ModelProviderDefinition(
                          name=self._build_provider_name(model.provider_type, model.provider_base_url),
                          provider_type=model.provider_type,
                          base_url=model.provider_base_url,
                          api_key=model.provider_api_key,
                          models_json=[],
                          default_model=model.model_name,
                          extra_config_json={"source": "legacy_model_backfill"},
                      )
                  )
              model.provider_id = provider.id
              self._append_provider_model(provider=provider, model=model)
              self.model_repository.update(model)
              self.provider_repository.update(provider)

      def _append_provider_model(
          self,
          *,
          provider: ModelProviderDefinition,
          model: ModelDefinition,
      ) -> None:
          """Add ``model`` to the provider's catalog unless it is already listed."""
          existing_items = list(provider.models_json or [])
          # Catalog entries may use either key spelling (test_provider reads
          # both), so check both before appending to avoid duplicates.
          already_listed = any(
              (item.get("model_id") or item.get("modelId")) == model.model_name
              for item in existing_items
          )
          if already_listed:
              return
          existing_items.append(
              {
                  "model_id": model.model_name,
                  "display_name": model.name,
                  "model_type": "chat",
              }
          )
          # Assign a new list rather than mutating the stored one in place.
          provider.models_json = existing_items

      def _build_provider_name(self, provider_type: str, base_url: str) -> str:
          """Build a display name such as ``"Openai Compatible - api.example.com"``."""
          label = provider_type.replace("_", " ").title()
          # Split only on the first "//" and the first "/" after it, so a "//"
          # later in the path cannot be mistaken for the scheme separator.
          host = base_url.split("//", 1)[-1].split("/", 1)[0]
          return f"{label} - {host}" if host else label
  400. class _ResolvedModelProvider:
  401. def __init__(
  402. self,
  403. *,
  404. provider_type: str,
  405. provider_base_url: str,
  406. provider_api_key: str | None) -> None:
  407. self.provider_type = provider_type
  408. self.provider_base_url = provider_base_url
  409. self.provider_api_key = provider_api_key