| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596 |
import time
from dataclasses import dataclass

from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract

from app.bootstrap.settings import ModelGatewayServiceSettings
from app.db.models import ModelDefinition, ModelProviderDefinition
from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
from app.schemas.model import (
    DiscoverModelsData,
    DiscoverModelsRequestDto,
    ModelCreateRequest,
    ModelCreateRequestDto,
    ModelDeleteRequestDto,
    ModelDto,
    ModelItemDto,
    ModelProviderCreateRequestDto,
    ModelProviderDeleteRequestDto,
    ModelProviderDto,
    ModelProviderTestData,
    ModelProviderTestRequestDto,
    ModelProviderUpdateRequestDto,
    ModelResponse,
    ModelStatusUpdateRequest,
    ModelTestData,
    ModelTestRequest,
    ModelTestRequestDto,
    ModelTestResponse,
    ModelUpdateRequest,
    ModelUpdateRequestDto,
    _to_snake_model_item,
)
class ModelGatewayApplicationService:
    """Application service for model definitions, model providers and chat routing.

    Wires the model and provider repositories to the provider client. Its
    responsibilities are:

    * CRUD operations for models and providers;
    * translating camelCase contract DTOs into internal snake_case requests;
    * discovering provider models and mirroring them into model definitions;
    * routing chat-completion requests to the right provider connection.
    """

    # Maximum length of a generated model code (see ``_build_model_code``).
    _MODEL_CODE_MAX_LENGTH = 64
    # Timeout assigned to model definitions created by provider discovery.
    _DISCOVERED_MODEL_TIMEOUT_SECONDS = 60.0
    # Model types that map to exactly one capability of the same name.
    _SINGLE_CAPABILITY_MODEL_TYPES = frozenset(
        {"embedding", "image", "audio", "video", "rerank", "moderation"}
    )
    # Provider types whose chat-capable models are also tagged with "tools".
    _TOOL_CAPABLE_PROVIDER_TYPES = frozenset(
        {"openai", "anthropic", "deepseek", "openai_compatible"}
    )
    # (contract camelCase field, internal ModelUpdateRequest field) pairs used by
    # ``update_model_from_contract``.
    _MODEL_UPDATE_FIELD_MAP: tuple[tuple[str, str], ...] = (
        ("name", "name"),
        ("providerId", "provider_id"),
        ("providerType", "provider_type"),
        ("providerBaseUrl", "provider_base_url"),
        ("providerApiKey", "provider_api_key"),
        ("modelName", "model_name"),
        ("description", "description"),
        ("capabilities", "capabilities_json"),
        ("contextWindow", "context_window"),
        ("maxOutputTokens", "max_output_tokens"),
        ("defaultTemperature", "default_temperature"),
        ("timeoutSeconds", "timeout_seconds"),
        ("metadata", "metadata_json"),
    )

    def __init__(
        self,
        *,
        model_repository: ModelDefinitionRepository,
        provider_repository: ModelProviderDefinitionRepository,
        provider_client: ModelProviderClient,
        settings: ModelGatewayServiceSettings,
    ) -> None:
        self.model_repository = model_repository
        self.provider_repository = provider_repository
        self.provider_client = provider_client
        self.settings = settings

    # ------------------------------------------------------------------ models

    def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
        """Create a model definition, or refresh it if its provider already has it.

        When ``payload.provider_id`` references a provider that already has a model
        with the same ``model_name``, that definition is updated in place (upsert).
        Otherwise a new definition is created. Its connection settings come from the
        provider when one is given, and from the payload otherwise.

        Raises:
            ValueError: if ``provider_id`` references no provider, or the resulting
                model code is already in use.
        """
        provider = self._get_provider_or_raise(payload.provider_id)
        if provider is not None:
            existing = self.model_repository.get_by_provider_model(
                provider_id=provider.id,
                model_name=payload.model_name,
            )
            if existing is not None:
                return self._update_model_from_create_payload(existing, provider, payload)

        code = payload.code or self._build_model_code(payload.name, payload.model_name)
        if self.model_repository.get_by_code(code) is not None:
            raise ValueError(f"model code already exists: {code}")

        if provider is not None:
            provider_type = provider.provider_type
            provider_base_url = provider.base_url
            provider_api_key = provider.api_key
        else:
            provider_type = payload.provider_type
            # Never persist the literal string "None" as a base URL. An absent URL
            # becomes "", which matches what create_model_from_contract sends.
            provider_base_url = (
                str(payload.provider_base_url)
                if payload.provider_base_url is not None
                else ""
            )
            provider_api_key = payload.provider_api_key

        return self.model_repository.create(
            ModelDefinition(
                code=code,
                name=payload.name,
                provider_id=provider.id if provider is not None else None,
                provider_type=provider_type,
                provider_base_url=provider_base_url,
                provider_api_key=provider_api_key,
                model_name=payload.model_name,
                status=payload.status,
                description=payload.description,
                capabilities_json=payload.capabilities_json,
                context_window=payload.context_window,
                max_output_tokens=payload.max_output_tokens,
                default_temperature=payload.default_temperature,
                timeout_seconds=payload.timeout_seconds,
                metadata_json=payload.metadata_json,
            )
        )

    def _update_model_from_create_payload(
        self,
        existing: ModelDefinition,
        provider: ModelProviderDefinition,
        payload: ModelCreateRequest,
    ) -> ModelDefinition:
        """Refresh an existing provider model from a create payload (upsert path).

        ``code`` and ``status`` are left unchanged. ``context_window`` keeps its old
        value when the payload's is falsy. Metadata is merged, with payload keys
        winning.
        """
        existing.name = payload.name
        self._copy_provider_connection(existing, provider)
        existing.description = payload.description
        existing.capabilities_json = payload.capabilities_json
        existing.context_window = payload.context_window or existing.context_window
        existing.max_output_tokens = payload.max_output_tokens
        existing.default_temperature = payload.default_temperature
        existing.timeout_seconds = payload.timeout_seconds
        existing.metadata_json = {
            **(existing.metadata_json or {}),
            **(payload.metadata_json or {}),
        }
        return self.model_repository.update(existing)

    def list_models(self) -> list[ModelDefinition]:
        """Return every stored model definition."""
        return self.model_repository.list_all()

    def create_model_from_contract(self, payload: ModelCreateRequestDto) -> ModelDefinition:
        """Create a model from the camelCase contract DTO (see ``create_model``).

        A missing ``providerBaseUrl`` is passed on as ``""``.
        """
        return self.create_model(
            ModelCreateRequest(
                name=payload.name,
                provider_id=payload.providerId,
                provider_type=payload.providerType,
                provider_base_url=payload.providerBaseUrl or "",
                provider_api_key=payload.providerApiKey,
                model_name=payload.modelName,
                description=payload.description,
                capabilities_json=payload.capabilities,
                context_window=payload.contextWindow,
                max_output_tokens=payload.maxOutputTokens,
                default_temperature=payload.defaultTemperature,
                timeout_seconds=payload.timeoutSeconds,
                metadata_json=payload.metadata,
            )
        )

    def update_model_from_contract(
        self, payload: ModelUpdateRequestDto
    ) -> ModelDefinition | None:
        """Apply a partial camelCase contract update to a model.

        Only fields the client set AND that are not None are forwarded to
        ``update_model``. Returns None when the model does not exist.
        """
        updates = payload.model_dump(exclude_unset=True)
        # NOTE(review): explicit nulls are dropped along with unset fields, so a
        # client cannot clear a field (e.g. description or providerApiKey) here.
        mapped_updates = {
            internal_name: updates[contract_name]
            for contract_name, internal_name in self._MODEL_UPDATE_FIELD_MAP
            if updates.get(contract_name) is not None
        }
        return self.update_model(
            model_id=payload.modelId,
            payload=ModelUpdateRequest(**mapped_updates),
        )

    def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
        """Apply the fields explicitly set on ``payload`` to a model.

        If ``provider_id`` is set and resolves to a provider, that provider's type,
        base URL and API key overwrite any connection fields in the payload.

        Returns:
            The updated entity, or None when no model has ``model_id``.

        Raises:
            ValueError: if the new code belongs to another model, or ``provider_id``
                references no provider.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        updates = payload.model_dump(exclude_unset=True)
        if "code" in updates and updates["code"] != entity.code:
            clash = self.model_repository.get_by_code(str(updates["code"]))
            if clash is not None and clash.id != entity.id:
                raise ValueError(f"model code already exists: {updates['code']}")
        provider = self._get_provider_or_raise(updates.get("provider_id"))
        if provider is not None:
            updates["provider_id"] = provider.id
            updates["provider_type"] = provider.provider_type
            updates["provider_base_url"] = provider.base_url
            updates["provider_api_key"] = provider.api_key
        for field_name, value in updates.items():
            # URL-typed payload values are stored as plain strings.
            if field_name == "provider_base_url" and value is not None:
                value = str(value)
            setattr(entity, field_name, value)
        return self.model_repository.update(entity)

    def update_model_status(
        self,
        model_id: str,
        payload: ModelStatusUpdateRequest,
    ) -> ModelDefinition | None:
        """Set a model's status. Returns None when the model does not exist."""
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        entity.status = payload.status
        return self.model_repository.update(entity)

    def delete_model(self, model_id: str) -> bool:
        """Delete a model. Returns False when it did not exist, True otherwise."""
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return False
        self.model_repository.delete(entity)
        return True

    def delete_model_from_contract(self, payload: ModelDeleteRequestDto) -> bool:
        """Contract wrapper around ``delete_model``."""
        return self.delete_model(payload.modelId)

    # ------------------------------------------------------------------- chat

    def create_chat_completion(
        self, payload: ChatCompletionRequestContract
    ) -> ChatCompletionResponseContract:
        """Forward a chat-completion request to the appropriate provider.

        Case 1: ``payload.model`` matches a configured model via
        ``model_repository.get_active_for_request``. Then the request is rewritten
        to that model's upstream ``model_name``:

        * a missing temperature falls back to the model's default;
        * a falsy ``max_tokens`` falls back to its ``max_output_tokens``;
        * the request is sent to the model's provider with the model's timeout.

        Case 2: no configured model matches, or no model was requested. Then the
        request goes to ``settings.provider_type``, with the model defaulting to
        ``settings.default_model``.
        """
        configured_model = (
            self.model_repository.get_active_for_request(payload.model)
            if payload.model
            else None
        )
        if configured_model is None:
            fallback_payload = payload.model_copy(
                update={"model": payload.model or self.settings.default_model}
            )
            return self.provider_client.create_chat_completion(
                fallback_payload,
                provider_type=self.settings.provider_type,
            )

        connection = self._resolve_model_provider(configured_model)
        temperature = (
            payload.temperature
            if payload.temperature is not None
            else configured_model.default_temperature
        )
        resolved_payload = payload.model_copy(
            update={
                "model": configured_model.model_name,
                "temperature": temperature,
                "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
            }
        )
        return self.provider_client.create_chat_completion(
            resolved_payload,
            provider_type=connection.provider_type,
            provider_base_url=connection.provider_base_url,
            provider_api_key=connection.provider_api_key,
            timeout_seconds=configured_model.timeout_seconds,
        )

    def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
        """Send a one-off test prompt to a model. Returns None if the model is unknown."""
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        response = self._run_model_test(entity, payload)
        return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)

    def test_model_from_contract(self, payload: ModelTestRequestDto) -> ModelTestData | None:
        """Contract wrapper around model testing; loads the model only once.

        Returns None when the model does not exist.
        """
        # Build (and validate) the request before touching the repository.
        request = ModelTestRequest(
            prompt=payload.prompt,
            system_prompt=payload.systemPrompt,
            temperature=payload.temperature,
            max_tokens=payload.maxTokens,
        )
        entity = self.model_repository.get_by_id(payload.modelId)
        if entity is None:
            return None
        response = self._run_model_test(entity, request)
        return ModelTestData(model=ModelDto.from_entity(entity), response=response)

    def _run_model_test(
        self, entity: ModelDefinition, payload: ModelTestRequest
    ) -> ChatCompletionResponseContract:
        """Send the test prompt (plus optional system prompt) to ``entity``'s provider.

        Missing temperature and falsy ``max_tokens`` fall back to the model's defaults.
        """
        connection = self._resolve_model_provider(entity)
        messages = []
        if payload.system_prompt:
            messages.append({"role": "system", "content": payload.system_prompt})
        messages.append({"role": "user", "content": payload.prompt})
        temperature = (
            payload.temperature
            if payload.temperature is not None
            else entity.default_temperature
        )
        return self.provider_client.create_chat_completion(
            ChatCompletionRequestContract(
                model=entity.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=payload.max_tokens or entity.max_output_tokens,
            ),
            provider_type=connection.provider_type,
            provider_base_url=connection.provider_base_url,
            provider_api_key=connection.provider_api_key,
            timeout_seconds=entity.timeout_seconds,
        )

    # --------------------------------------------------------------- providers

    def list_providers(self) -> list[ModelProviderDefinition]:
        """Return all providers, first backfilling providers for legacy models."""
        self._ensure_legacy_model_providers()
        return self.provider_repository.list_all()

    def create_provider(self, payload: ModelProviderCreateRequestDto) -> ModelProviderDefinition:
        """Create a provider, then try to discover and sync its models.

        Discovery is best effort (``raise_on_empty=False``). A failure falls back
        to the declared models and does not fail the creation.
        """
        provider = self.provider_repository.create(
            ModelProviderDefinition(
                name=payload.name,
                provider_type=payload.providerType,
                base_url=str(payload.baseUrl),
                api_key=payload.apiKey,
                models_json=[_to_snake_model_item(item) for item in payload.models or []],
                default_model=payload.defaultModel,
                extra_config_json=payload.extraConfig,
            )
        )
        self._refresh_and_sync_provider_models(provider, raise_on_empty=False)
        return provider

    def update_provider(
        self, payload: ModelProviderUpdateRequestDto
    ) -> ModelProviderDefinition | None:
        """Apply a partial update to a provider, then re-sync its models (best effort).

        Returns None when the provider does not exist.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return None
        updates = payload.model_dump(exclude_unset=True)
        updates.pop("providerId", None)
        # NOTE(review): keys not handled below (e.g. a providerType field, if the
        # DTO has one) are silently ignored.
        for key, value in updates.items():
            if key == "baseUrl":
                # A null baseUrl keeps the current URL.
                entity.base_url = str(value) if value is not None else entity.base_url
            elif key == "apiKey":
                entity.api_key = value
            elif key == "defaultModel":
                entity.default_model = value
            elif key == "extraConfig":
                entity.extra_config_json = value
            elif key == "models":
                # model_dump turns nested DTOs into dicts, so rebuild them first.
                entity.models_json = [
                    _to_snake_model_item(ModelItemDto(**item) if isinstance(item, dict) else item)
                    for item in value or []
                ]
            elif key == "name":
                entity.name = value
        updated = self.provider_repository.update(entity)
        self._refresh_and_sync_provider_models(updated, raise_on_empty=False)
        return updated

    def delete_provider(self, payload: ModelProviderDeleteRequestDto) -> bool:
        """Delete a provider. Returns False when it did not exist.

        Models that referenced it are not touched here. ``_resolve_model_provider``
        falls back to a model's own connection fields when its provider is missing.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return False
        self.provider_repository.delete(entity)
        return True

    def test_provider(self, payload: ModelProviderTestRequestDto) -> ModelProviderTestData | None:
        """Check a provider by listing its models.

        If the live listing fails, the models stored on the provider are reported
        instead (with ``latencyMs=0``). The client error is re-raised only when no
        models are stored.

        Returns:
            The test result, or None when the provider does not exist.

        Raises:
            ModelProviderClientError: the live listing failed and no stored models
                exist.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return None
        started = time.perf_counter()
        try:
            models = self.provider_client.list_models(
                provider_type=entity.provider_type,
                provider_base_url=entity.base_url,
                provider_api_key=entity.api_key,
            )
        except ModelProviderClientError:
            models = list(entity.models_json or [])
            if not models:
                raise
            latency_ms = 0
        else:
            latency_ms = round((time.perf_counter() - started) * 1000)

        # Live items use camelCase keys; stored models_json items use snake_case.
        model_ids = []
        for item in models:
            model_id = item.get("modelId") or item.get("model_id")
            if model_id:
                model_ids.append(str(model_id))
        return ModelProviderTestData(
            success=True,
            message="Connection configuration is available.",
            latencyMs=latency_ms,
            modelList=model_ids,
        )

    def discover_models(self, payload: DiscoverModelsRequestDto) -> DiscoverModelsData:
        """Discover available models, in order of preference.

        1. A stored provider (``providerId``): refresh and sync its models.
        2. An ad-hoc connection (``baseUrl`` + ``apiKey``): list it live.
        3. Otherwise: the built-in catalog for the requested (or default) type.

        The provider type falls back to ``settings.provider_type``. That resolved
        type is used both for the client call and for the reported type.
        """
        if payload.providerId:
            provider = self.provider_repository.get_by_id(payload.providerId)
            if provider is not None:
                discovered = self._refresh_and_sync_provider_models(
                    provider, raise_on_empty=True
                )
                return DiscoverModelsData(providerType=provider.provider_type, models=discovered)
            # NOTE(review): an unknown providerId falls through to the lookups
            # below rather than raising.
        provider_type = payload.providerType or self.settings.provider_type
        if payload.baseUrl:
            discovered = [
                ModelItemDto(**item)
                for item in self.provider_client.list_models(
                    provider_type=provider_type,
                    provider_base_url=str(payload.baseUrl),
                    provider_api_key=payload.apiKey,
                )
            ]
            return DiscoverModelsData(providerType=provider_type, models=discovered)
        return DiscoverModelsData(
            providerType=provider_type,
            models=self._default_model_catalog(provider_type),
        )

    # ----------------------------------------------------------------- helpers

    def _default_model_catalog(self, provider_type: str) -> list[ModelItemDto]:
        """Return the built-in model list for ``provider_type`` ([] if unknown).

        The DTOs are rebuilt on every call, so callers may mutate the result.
        """
        catalogs = {
            "openai": [
                ModelItemDto(
                    modelId="gpt-4.1-mini",
                    displayName="GPT-4.1 Mini",
                    modelType="chat",
                    ownedBy="openai",
                    contextWindow=1047576,
                ),
                ModelItemDto(
                    modelId="text-embedding-3-small",
                    displayName="Text Embedding 3 Small",
                    modelType="embedding",
                    ownedBy="openai",
                ),
            ],
            "openai_compatible": [
                ModelItemDto(
                    modelId="gpt-4.1-mini",
                    displayName="OpenAI Compatible Chat",
                    modelType="chat",
                    contextWindow=128000,
                ),
                ModelItemDto(
                    modelId="text-embedding-3-small",
                    displayName="OpenAI Compatible Embedding",
                    modelType="embedding",
                ),
            ],
            "ollama": [
                ModelItemDto(
                    modelId="llama3.1:8b",
                    displayName="LLaMA 3.1 8B",
                    modelType="chat",
                    ownedBy="meta",
                    contextWindow=131072,
                ),
                ModelItemDto(
                    modelId="nomic-embed-text",
                    displayName="Nomic Embed Text",
                    modelType="embedding",
                    ownedBy="nomic",
                ),
            ],
            "deepseek": [
                ModelItemDto(
                    modelId="deepseek-chat",
                    displayName="DeepSeek Chat",
                    modelType="chat",
                    ownedBy="deepseek",
                    contextWindow=64000,
                ),
                ModelItemDto(
                    modelId="deepseek-reasoner",
                    displayName="DeepSeek Reasoner",
                    modelType="reasoning",
                    ownedBy="deepseek",
                    contextWindow=64000,
                ),
            ],
        }
        return catalogs.get(provider_type, [])

    def _build_model_code(self, name: str, model_name: str) -> str:
        """Derive a model code that is not yet used in the repository.

        Steps:

        1. Take ``f"{name}_{model_name}"``, lower-case it and replace every
           non-alphanumeric character with ``_``.
        2. Strip leading/trailing underscores, falling back to ``"model"`` if
           nothing is left.
        3. Truncate to ``_MODEL_CODE_MAX_LENGTH``.
        4. While the code is taken, append ``_1``, ``_2``, ... and shorten the base
           so the code stays within the limit.

        NOTE(review): ``str.isalnum`` accepts non-ASCII letters, so codes may
        contain non-ASCII characters. Confirm that is acceptable.
        """
        max_length = self._MODEL_CODE_MAX_LENGTH
        base = "".join(
            char.lower() if char.isalnum() else "_" for char in f"{name}_{model_name}"
        ).strip("_") or "model"
        candidate = base[:max_length]
        suffix = 1
        while self.model_repository.get_by_code(candidate) is not None:
            suffix_text = f"_{suffix}"
            candidate = f"{base[: max_length - len(suffix_text)]}{suffix_text}"
            suffix += 1
        return candidate

    def _get_provider_or_raise(self, provider_id: str | None) -> ModelProviderDefinition | None:
        """Return the provider for ``provider_id``, or None when no id is given.

        Raises:
            ValueError: if an id is given but no provider has it.
        """
        if provider_id is None:
            return None
        provider = self.provider_repository.get_by_id(provider_id)
        if provider is None:
            raise ValueError(f"model provider not found: {provider_id}")
        return provider

    def _resolve_model_provider(self, model: ModelDefinition) -> "_ResolvedModelProvider":
        """Return the connection settings to use for ``model``.

        The linked provider record is preferred. If the model has no
        ``provider_id``, or the provider no longer exists, the model's own copied
        connection fields are used.
        """
        provider = self.provider_repository.get_by_id(model.provider_id) if model.provider_id else None
        if provider is not None:
            return _ResolvedModelProvider(
                provider_type=provider.provider_type,
                provider_base_url=provider.base_url,
                provider_api_key=provider.api_key,
            )
        return _ResolvedModelProvider(
            provider_type=model.provider_type,
            provider_base_url=model.provider_base_url,
            provider_api_key=model.provider_api_key,
        )

    @staticmethod
    def _copy_provider_connection(
        model: ModelDefinition, provider: ModelProviderDefinition
    ) -> None:
        """Copy the provider's type, base URL and API key onto the model."""
        model.provider_type = provider.provider_type
        model.provider_base_url = provider.base_url
        model.provider_api_key = provider.api_key

    def _ensure_legacy_model_providers(self) -> None:
        """Backfill provider records for models that have a base URL but no provider_id.

        Each such model is handled as follows:

        * It is linked to an existing provider with the same type and base URL
          (``get_by_connection``), or to a newly created provider tagged
          ``{"source": "legacy_model_backfill"}``.
        * It is appended to that provider's ``models_json``.
        * Both records are saved.
        """
        legacy_models = [
            model
            for model in self.model_repository.list_all()
            if model.provider_id is None and model.provider_base_url
        ]
        for model in legacy_models:
            provider = self.provider_repository.get_by_connection(
                provider_type=model.provider_type,
                base_url=model.provider_base_url,
            )
            if provider is None:
                provider = self.provider_repository.create(
                    ModelProviderDefinition(
                        name=self._build_provider_name(model.provider_type, model.provider_base_url),
                        provider_type=model.provider_type,
                        base_url=model.provider_base_url,
                        api_key=model.provider_api_key,
                        models_json=[],
                        default_model=model.model_name,
                        extra_config_json={"source": "legacy_model_backfill"},
                    )
                )
            model.provider_id = provider.id
            self._append_provider_model(provider=provider, model=model)
            self.model_repository.update(model)
            self.provider_repository.update(provider)

    def _append_provider_model(
        self,
        *,
        provider: ModelProviderDefinition,
        model: ModelDefinition,
    ) -> None:
        """Add ``model`` to ``provider.models_json`` as a chat item, unless present.

        Presence is checked by ``model_id``. A new list is assigned rather than
        mutated in place, so the change is picked up when the provider is saved.
        """
        existing_items = list(provider.models_json or [])
        if any(item.get("model_id") == model.model_name for item in existing_items):
            return
        existing_items.append(
            {
                "model_id": model.model_name,
                "display_name": model.name,
                "model_type": "chat",
            }
        )
        provider.models_json = existing_items

    def _sync_provider_models(
        self,
        *,
        provider: ModelProviderDefinition,
        models: list[ModelItemDto],
    ) -> None:
        """Create or refresh a model definition for each discovered item.

        Items with a blank ``modelId`` are skipped.

        New definitions are created ``active`` with discovery defaults.

        For existing definitions the following are overwritten: name (when a
        display name is given), connection fields, capabilities and context window
        (when given). A ``source`` key is also merged into their metadata.
        """
        for item in models:
            model_name = item.modelId.strip()
            if not model_name:
                continue
            existing = self.model_repository.get_by_provider_model(
                provider_id=provider.id,
                model_name=model_name,
            )
            capabilities = self._capabilities_for_model_item(item, provider.provider_type)
            if existing is None:
                display_name = item.displayName or model_name
                self.model_repository.create(
                    ModelDefinition(
                        code=self._build_model_code(display_name, model_name),
                        name=display_name,
                        provider_id=provider.id,
                        provider_type=provider.provider_type,
                        provider_base_url=provider.base_url,
                        provider_api_key=provider.api_key,
                        model_name=model_name,
                        status="active",
                        description=None,
                        capabilities_json=capabilities,
                        context_window=item.contextWindow,
                        max_output_tokens=None,
                        default_temperature=None,
                        timeout_seconds=self._DISCOVERED_MODEL_TIMEOUT_SECONDS,
                        metadata_json={"source": "provider_discovery"},
                    )
                )
                continue
            existing.name = item.displayName or existing.name
            self._copy_provider_connection(existing, provider)
            existing.capabilities_json = capabilities
            existing.context_window = item.contextWindow or existing.context_window
            existing.metadata_json = {
                **(existing.metadata_json or {}),
                "source": "provider_discovery",
            }
            self.model_repository.update(existing)

    def _refresh_and_sync_provider_models(
        self,
        provider: ModelProviderDefinition,
        *,
        raise_on_empty: bool,
    ) -> list[ModelItemDto]:
        """Re-discover a provider's models and mirror them into model definitions.

        If the live listing fails, two fallbacks are tried in order:

        1. the models stored on the provider;
        2. for ``deepseek`` only, the built-in catalog.

        When anything is found:

        * ``models_json`` is rewritten from the result;
        * ``default_model`` is set if it was empty;
        * the provider is saved and its models are synced.

        Returns:
            The discovered items, or [] when nothing could be determined.

        Raises:
            ModelProviderClientError: the listing failed, no fallback models exist,
                and ``raise_on_empty`` is True.
        """
        try:
            discovered = [
                ModelItemDto(**item)
                for item in self.provider_client.list_models(
                    provider_type=provider.provider_type,
                    provider_base_url=provider.base_url,
                    provider_api_key=provider.api_key,
                )
            ]
        except ModelProviderClientError:
            discovered = ModelProviderDto.from_entity(provider).models
            if not discovered and provider.provider_type == "deepseek":
                discovered = self._default_model_catalog("deepseek")
            if not discovered and raise_on_empty:
                raise
        if not discovered:
            return []
        provider.models_json = [_to_snake_model_item(item) for item in discovered]
        if provider.default_model is None:
            provider.default_model = discovered[0].modelId
        self.provider_repository.update(provider)
        self._sync_provider_models(provider=provider, models=discovered)
        return discovered

    def _capabilities_for_model_item(
        self,
        item: ModelItemDto,
        provider_type: str,
    ) -> list[str]:
        """Map a discovered item's ``modelType`` to a sorted capability list.

        The mapping is:

        * ``"reasoning"`` gives ``chat`` + ``reasoning``;
        * a type in ``_SINGLE_CAPABILITY_MODEL_TYPES`` gives that type alone;
        * anything else gives ``chat``.

        Chat-capable models of tool-capable providers also get ``tools``.
        """
        model_type = item.modelType
        if model_type == "reasoning":
            capabilities = {"chat", "reasoning"}
        elif model_type in self._SINGLE_CAPABILITY_MODEL_TYPES:
            capabilities = {model_type}
        else:
            capabilities = {"chat"}
        if provider_type in self._TOOL_CAPABLE_PROVIDER_TYPES and "chat" in capabilities:
            capabilities.add("tools")
        return sorted(capabilities)

    def _build_provider_name(self, provider_type: str, base_url: str) -> str:
        """Return ``"<Type Title> - <host>"``, or just the title when no host is found.

        The host is the text after ``//`` up to the next ``/``.
        """
        label = provider_type.replace("_", " ").title()
        host = base_url.split("//")[-1].split("/")[0]
        return f"{label} - {host}" if host else label
- class _ResolvedModelProvider:
- def __init__(
- self,
- *,
- provider_type: str,
- provider_base_url: str,
- provider_api_key: str | None) -> None:
- self.provider_type = provider_type
- self.provider_base_url = provider_base_url
- self.provider_api_key = provider_api_key
|