| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
- from app.bootstrap.settings import ModelGatewayServiceSettings
- from app.db.models import ModelDefinition, ModelProviderDefinition
- from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
- from app.infrastructure.provider import ModelProviderClient
- from app.schemas.model import (
- DiscoverModelsData,
- DiscoverModelsRequestDto,
- ModelCreateRequest,
- ModelCreateRequestDto,
- ModelDeleteRequestDto,
- ModelDto,
- ModelItemDto,
- ModelProviderCreateRequestDto,
- ModelProviderDeleteRequestDto,
- ModelProviderDto,
- ModelProviderTestData,
- ModelProviderTestRequestDto,
- ModelProviderUpdateRequestDto,
- ModelStatusUpdateRequest,
- ModelTestData,
- ModelTestRequest,
- ModelTestRequestDto,
- ModelTestResponse,
- ModelUpdateRequest,
- ModelUpdateRequestDto,
- _to_snake_model_item,
- )
class ModelGatewayApplicationService:
    """Application service for model definitions, model providers and chat completions.

    Coordinates the model and provider repositories with a ``ModelProviderClient``.
    The ``*_from_contract`` methods and the provider methods that take
    ``*RequestDto`` payloads translate camelCase request DTOs into the
    snake_case requests / ORM entities used internally.
    """

    def __init__(
        self,
        *,
        model_repository: ModelDefinitionRepository,
        provider_repository: ModelProviderDefinitionRepository,
        provider_client: ModelProviderClient,
        settings: ModelGatewayServiceSettings,
    ) -> None:
        self.model_repository = model_repository
        self.provider_repository = provider_repository
        self.provider_client = provider_client
        self.settings = settings

    # ------------------------------------------------------------------ models

    def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
        """Create and persist a model definition.

        When ``payload.provider_id`` is set, the provider's type, base URL and
        API key are copied onto the model. Otherwise those connection fields
        come from the payload itself. If ``payload.code`` is empty, a unique
        code is derived from the name and model name.

        Raises:
            ValueError: if ``provider_id`` does not exist, or the code is taken.
        """
        provider = self._get_provider_or_raise(payload.provider_id)
        code = payload.code or self._build_model_code(payload.name, payload.model_name)
        if self.model_repository.get_by_code(code) is not None:
            raise ValueError(f"model code already exists: {code}")
        if provider is not None:
            provider_id = provider.id
            provider_type = provider.provider_type
            provider_base_url = provider.base_url
            provider_api_key = provider.api_key
        else:
            provider_id = None
            provider_type = payload.provider_type
            # str() normalizes URL objects. A missing URL must not become the
            # literal string "None": the legacy backfill treats any truthy
            # provider_base_url as a real connection. "" matches what
            # create_model_from_contract passes for a missing URL.
            provider_base_url = (
                str(payload.provider_base_url)
                if payload.provider_base_url is not None
                else ""
            )
            provider_api_key = payload.provider_api_key
        return self.model_repository.create(
            ModelDefinition(
                code=code,
                name=payload.name,
                provider_id=provider_id,
                provider_type=provider_type,
                provider_base_url=provider_base_url,
                provider_api_key=provider_api_key,
                model_name=payload.model_name,
                status=payload.status,
                description=payload.description,
                capabilities_json=payload.capabilities_json,
                context_window=payload.context_window,
                max_output_tokens=payload.max_output_tokens,
                default_temperature=payload.default_temperature,
                timeout_seconds=payload.timeout_seconds,
                metadata_json=payload.metadata_json,
            )
        )

    def list_models(self) -> list[ModelDefinition]:
        """Return every model definition from the repository."""
        return self.model_repository.list_all()

    def create_model_from_contract(self, payload: ModelCreateRequestDto) -> ModelDefinition:
        """Create a model from a camelCase DTO; see ``create_model``.

        A missing ``providerBaseUrl`` is passed on as ``""``.
        """
        return self.create_model(
            ModelCreateRequest(
                name=payload.name,
                provider_id=payload.providerId,
                provider_type=payload.providerType,
                provider_base_url=payload.providerBaseUrl or "",
                provider_api_key=payload.providerApiKey,
                model_name=payload.modelName,
                description=payload.description,
                capabilities_json=payload.capabilities,
                context_window=payload.contextWindow,
                max_output_tokens=payload.maxOutputTokens,
                default_temperature=payload.defaultTemperature,
                timeout_seconds=payload.timeoutSeconds,
                metadata_json=payload.metadata,
            )
        )

    def update_model_from_contract(self, payload: ModelUpdateRequestDto) -> ModelDefinition | None:
        """Apply a camelCase partial update to the model ``payload.modelId``.

        A field is forwarded only if it was explicitly set on the DTO and is
        not ``None``.

        NOTE(review): an explicit ``null`` is therefore ignored rather than
        clearing the stored value. Confirm that this is the intended contract.

        Returns:
            The updated model, or ``None`` if the model does not exist.
        """
        # camelCase DTO field -> snake_case ModelUpdateRequest field.
        field_map = {
            "name": "name",
            "providerId": "provider_id",
            "providerType": "provider_type",
            "providerBaseUrl": "provider_base_url",
            "providerApiKey": "provider_api_key",
            "modelName": "model_name",
            "description": "description",
            "capabilities": "capabilities_json",
            "contextWindow": "context_window",
            "maxOutputTokens": "max_output_tokens",
            "defaultTemperature": "default_temperature",
            "timeoutSeconds": "timeout_seconds",
            "metadata": "metadata_json",
        }
        updates = payload.model_dump(exclude_unset=True)
        mapped_updates = {
            snake_key: updates[camel_key]
            for camel_key, snake_key in field_map.items()
            if updates.get(camel_key) is not None
        }
        return self.update_model(
            model_id=payload.modelId,
            payload=ModelUpdateRequest(**mapped_updates),
        )

    def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
        """Apply the explicitly-set fields of ``payload`` to model ``model_id``.

        If ``provider_id`` is set, that provider's type, base URL and API key
        overwrite any connection fields given in the same payload.

        Returns:
            The updated model, or ``None`` if the model does not exist.

        Raises:
            ValueError: if a new ``code`` belongs to another model, or
                ``provider_id`` does not exist.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        updates = payload.model_dump(exclude_unset=True)
        if "code" in updates and updates["code"] != entity.code:
            existing = self.model_repository.get_by_code(str(updates["code"]))
            if existing is not None and existing.id != entity.id:
                raise ValueError(f"model code already exists: {updates['code']}")
        provider = self._get_provider_or_raise(updates.get("provider_id"))
        if provider is not None:
            updates["provider_id"] = provider.id
            updates["provider_type"] = provider.provider_type
            updates["provider_base_url"] = provider.base_url
            updates["provider_api_key"] = provider.api_key
        for key, value in updates.items():
            # Normalize URL objects to plain strings before persisting.
            if key == "provider_base_url" and value is not None:
                value = str(value)
            setattr(entity, key, value)
        return self.model_repository.update(entity)

    def update_model_status(
        self,
        model_id: str,
        payload: ModelStatusUpdateRequest,
    ) -> ModelDefinition | None:
        """Set the status of model ``model_id``.

        Returns:
            The updated model, or ``None`` if it does not exist.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        entity.status = payload.status
        return self.model_repository.update(entity)

    def delete_model(self, model_id: str) -> bool:
        """Delete model ``model_id``.

        Returns:
            ``True`` if the model was deleted, ``False`` if it was not found.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return False
        self.model_repository.delete(entity)
        return True

    def delete_model_from_contract(self, payload: ModelDeleteRequestDto) -> bool:
        """Delete the model ``payload.modelId``; see ``delete_model``."""
        return self.delete_model(payload.modelId)

    # ------------------------------------------------------- chat completions

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """Route a chat-completion request to the appropriate provider.

        If ``payload.model`` matches an active configured model (via
        ``model_repository.get_active_for_request``), the request is rewritten
        to that model's upstream ``model_name``. The model's default
        temperature and max output tokens fill in unset values. The request
        is then sent to the model's resolved provider with its timeout.

        Otherwise the request is forwarded unchanged, except that a missing
        model falls back to ``settings.default_model``. It uses
        ``settings.provider_type``. Base URL, API key and timeout are not
        passed and are left to the client (presumably its configured
        defaults; verify).
        """
        configured_model = None
        if payload.model:
            configured_model = self.model_repository.get_active_for_request(payload.model)
        if configured_model is not None:
            resolved_payload = payload.model_copy(
                update={
                    "model": configured_model.model_name,
                    # temperature=0 is a valid explicit choice, so test for None.
                    "temperature": payload.temperature
                    if payload.temperature is not None
                    else configured_model.default_temperature,
                    "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
                }
            )
            return self._complete_with_model(configured_model, resolved_payload)
        resolved_payload = payload.model_copy(
            update={"model": payload.model or self.settings.default_model}
        )
        return self.provider_client.create_chat_completion(
            resolved_payload,
            provider_type=self.settings.provider_type,
        )

    def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
        """Send a test prompt to model ``model_id``.

        Returns:
            The model together with the provider response, or ``None`` if the
            model does not exist.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        response = self._run_model_test(entity, payload)
        # NOTE(review): app.schemas.model is already imported at module level,
        # so this local import could likely move to the top-level import list.
        from app.schemas.model import ModelResponse

        return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)

    def test_model_from_contract(self, payload: ModelTestRequestDto) -> ModelTestData | None:
        """Send a test prompt described by a camelCase DTO; see ``test_model``.

        Returns:
            The model DTO with the provider response, or ``None`` if the
            model does not exist.
        """
        # Build (and validate) the request before touching the repository.
        test_request = ModelTestRequest(
            prompt=payload.prompt,
            system_prompt=payload.systemPrompt,
            temperature=payload.temperature,
            max_tokens=payload.maxTokens,
        )
        # A single lookup: the same entity is used for the call and the reply.
        entity = self.model_repository.get_by_id(payload.modelId)
        if entity is None:
            return None
        response = self._run_model_test(entity, test_request)
        return ModelTestData(model=ModelDto.from_entity(entity), response=response)

    def _run_model_test(
        self,
        entity: ModelDefinition,
        payload: ModelTestRequest,
    ) -> ChatCompletionResponseContract:
        """Send ``payload``'s prompt (and optional system prompt) to ``entity``."""
        messages = []
        if payload.system_prompt:
            messages.append({"role": "system", "content": payload.system_prompt})
        messages.append({"role": "user", "content": payload.prompt})
        request = ChatCompletionRequestContract(
            model=entity.model_name,
            messages=messages,
            temperature=payload.temperature
            if payload.temperature is not None
            else entity.default_temperature,
            max_tokens=payload.max_tokens or entity.max_output_tokens,
        )
        return self._complete_with_model(entity, request)

    def _complete_with_model(
        self,
        model: ModelDefinition,
        request: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """Send ``request`` to ``model``'s resolved provider using the model's timeout."""
        provider = self._resolve_model_provider(model)
        return self.provider_client.create_chat_completion(
            request,
            provider_type=provider.provider_type,
            provider_base_url=provider.provider_base_url,
            provider_api_key=provider.provider_api_key,
            timeout_seconds=model.timeout_seconds,
        )

    # --------------------------------------------------------------- providers

    def list_providers(self) -> list[ModelProviderDefinition]:
        """Return all providers.

        Legacy models (no ``provider_id`` but with a base URL) are first
        backfilled into provider records. This read therefore writes to both
        repositories.
        """
        self._ensure_legacy_model_providers()
        return self.provider_repository.list_all()

    def create_provider(self, payload: ModelProviderCreateRequestDto) -> ModelProviderDefinition:
        """Create and persist a provider from a camelCase DTO."""
        return self.provider_repository.create(
            ModelProviderDefinition(
                name=payload.name,
                provider_type=payload.providerType,
                base_url=str(payload.baseUrl),
                api_key=payload.apiKey,
                models_json=[_to_snake_model_item(item) for item in payload.models],
                default_model=payload.defaultModel,
                extra_config_json=payload.extraConfig,
            )
        )

    def update_provider(
        self,
        payload: ModelProviderUpdateRequestDto,
    ) -> ModelProviderDefinition | None:
        """Apply the explicitly-set fields of a camelCase DTO to a provider.

        A ``None`` ``baseUrl`` or ``providerType`` keeps the current value.
        Other recognized fields are assigned as given. Unrecognized keys are
        ignored.

        Returns:
            The updated provider, or ``None`` if it does not exist.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return None
        updates = payload.model_dump(exclude_unset=True)
        updates.pop("providerId", None)
        for key, value in updates.items():
            if key == "baseUrl":
                entity.base_url = str(value) if value is not None else entity.base_url
            elif key == "providerType":
                # provider_type selects the client backend and feeds
                # _build_provider_name, so never overwrite it with None.
                if value is not None:
                    entity.provider_type = value
            elif key == "apiKey":
                entity.api_key = value
            elif key == "defaultModel":
                entity.default_model = value
            elif key == "extraConfig":
                entity.extra_config_json = value
            elif key == "models":
                # model_dump turns nested DTOs into dicts; rebuild them before
                # converting to the stored snake_case shape.
                entity.models_json = [
                    _to_snake_model_item(ModelItemDto(**item))
                    if isinstance(item, dict)
                    else _to_snake_model_item(item)
                    for item in value or []
                ]
            elif key == "name":
                entity.name = value
        return self.provider_repository.update(entity)

    def delete_provider(self, payload: ModelProviderDeleteRequestDto) -> bool:
        """Delete provider ``payload.providerId``.

        Models referencing it are not touched here. ``_resolve_model_provider``
        falls back to each model's own stored connection fields.

        Returns:
            ``True`` if deleted, ``False`` if not found.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return False
        self.provider_repository.delete(entity)
        return True

    def test_provider(self, payload: ModelProviderTestRequestDto) -> ModelProviderTestData | None:
        """Report the provider's configured model ids.

        NOTE(review): no connection attempt is made. ``success`` is always
        ``True`` and ``latencyMs`` is always ``0``.

        Returns:
            The test result, or ``None`` if the provider does not exist.
        """
        entity = self.provider_repository.get_by_id(payload.providerId)
        if entity is None:
            return None
        # Stored items may use either snake_case or camelCase keys.
        model_ids = [
            str(model_id)
            for item in entity.models_json or []
            if (model_id := item.get("model_id") or item.get("modelId"))
        ]
        return ModelProviderTestData(
            success=True,
            message="Connection configuration is available.",
            latencyMs=0,
            modelList=model_ids,
        )

    def discover_models(self, payload: DiscoverModelsRequestDto) -> DiscoverModelsData:
        """List models for a stored provider, or a built-in catalog.

        If ``payload.providerId`` names an existing provider, its configured
        models are returned. Otherwise, including when the id is unknown, the
        static catalog for ``payload.providerType`` is returned. If no type is
        given, ``settings.provider_type`` is used.
        """
        if payload.providerId:
            provider = self.provider_repository.get_by_id(payload.providerId)
            if provider is not None:
                return DiscoverModelsData(
                    providerType=provider.provider_type,
                    models=ModelProviderDto.from_entity(provider).models,
                )
        provider_type = payload.providerType or self.settings.provider_type
        return DiscoverModelsData(
            providerType=provider_type,
            models=self._default_model_catalog(provider_type),
        )

    # ----------------------------------------------------------------- helpers

    def _default_model_catalog(self, provider_type: str) -> list[ModelItemDto]:
        """Return the hard-coded model catalog for ``provider_type``.

        Returns ``[]`` for unknown types.
        """
        catalogs = {
            "openai": [
                ModelItemDto(
                    modelId="gpt-4.1-mini",
                    displayName="GPT-4.1 Mini",
                    modelType="chat",
                    ownedBy="openai",
                    contextWindow=1047576,
                ),
                ModelItemDto(
                    modelId="text-embedding-3-small",
                    displayName="Text Embedding 3 Small",
                    modelType="embedding",
                    ownedBy="openai",
                ),
            ],
            "openai_compatible": [
                ModelItemDto(
                    modelId="gpt-4.1-mini",
                    displayName="OpenAI Compatible Chat",
                    modelType="chat",
                    contextWindow=128000,
                ),
                ModelItemDto(
                    modelId="text-embedding-3-small",
                    displayName="OpenAI Compatible Embedding",
                    modelType="embedding",
                ),
            ],
            "ollama": [
                ModelItemDto(
                    modelId="llama3.1:8b",
                    displayName="LLaMA 3.1 8B",
                    modelType="chat",
                    ownedBy="meta",
                    contextWindow=131072,
                ),
                ModelItemDto(
                    modelId="nomic-embed-text",
                    displayName="Nomic Embed Text",
                    modelType="embedding",
                    ownedBy="nomic",
                ),
            ],
        }
        return catalogs.get(provider_type, [])

    def _build_model_code(self, name: str, model_name: str) -> str:
        """Derive a unique code (at most 64 chars) from ``name`` and ``model_name``.

        The code is lowercased. Every non-alphanumeric character becomes
        ``_``; ``str.isalnum`` also accepts non-ASCII letters. Leading and
        trailing underscores are stripped, and ``"model"`` is used if nothing
        remains. On collision, ``_1``, ``_2``, … are appended, truncating the
        base so the result stays within 64 characters.
        """
        base = "".join(
            char.lower() if char.isalnum() else "_"
            for char in f"{name}_{model_name}"
        ).strip("_") or "model"
        candidate = base[:64]
        suffix = 1
        while self.model_repository.get_by_code(candidate) is not None:
            suffix_text = f"_{suffix}"
            candidate = f"{base[:64 - len(suffix_text)]}{suffix_text}"
            suffix += 1
        return candidate

    def _get_provider_or_raise(self, provider_id: str | None) -> ModelProviderDefinition | None:
        """Return the provider ``provider_id``, or ``None`` when no id is given.

        Raises:
            ValueError: if an id is given but no such provider exists.
        """
        if provider_id is None:
            return None
        provider = self.provider_repository.get_by_id(provider_id)
        if provider is None:
            raise ValueError(f"model provider not found: {provider_id}")
        return provider

    def _resolve_model_provider(self, model: ModelDefinition) -> "_ResolvedModelProvider":
        """Pick the connection settings to use for ``model``.

        The linked provider's current settings win. If the model has no
        ``provider_id``, or the provider no longer exists, the model's own
        stored connection fields are used.
        """
        provider = self.provider_repository.get_by_id(model.provider_id) if model.provider_id else None
        if provider is not None:
            return _ResolvedModelProvider(
                provider_type=provider.provider_type,
                provider_base_url=provider.base_url,
                provider_api_key=provider.api_key,
            )
        return _ResolvedModelProvider(
            provider_type=model.provider_type,
            provider_base_url=model.provider_base_url,
            provider_api_key=model.provider_api_key,
        )

    def _ensure_legacy_model_providers(self) -> None:
        """Backfill provider records for models that predate providers.

        For each model with no ``provider_id`` but a non-empty base URL, the
        method finds a provider with the same type and base URL, creating one
        if needed. It then links the model to that provider and adds the
        model to the provider's ``models_json``. Both records are persisted.
        """
        legacy_models = [
            model
            for model in self.model_repository.list_all()
            if model.provider_id is None and model.provider_base_url
        ]
        for model in legacy_models:
            provider = self.provider_repository.get_by_connection(
                provider_type=model.provider_type,
                base_url=model.provider_base_url,
            )
            if provider is None:
                provider = self.provider_repository.create(
                    ModelProviderDefinition(
                        name=self._build_provider_name(model.provider_type, model.provider_base_url),
                        provider_type=model.provider_type,
                        base_url=model.provider_base_url,
                        api_key=model.provider_api_key,
                        models_json=[],
                        default_model=model.model_name,
                        extra_config_json={"source": "legacy_model_backfill"},
                    )
                )
            model.provider_id = provider.id
            self._append_provider_model(provider=provider, model=model)
            self.model_repository.update(model)
            self.provider_repository.update(provider)

    def _append_provider_model(
        self,
        *,
        provider: ModelProviderDefinition,
        model: ModelDefinition,
    ) -> None:
        """Add ``model`` as a chat entry to ``provider.models_json`` unless already listed.

        Existing entries may use ``model_id`` or ``modelId``; ``test_provider``
        reads both, so both are checked here too to avoid duplicates.
        Mutates ``provider`` in place and does not persist it.
        """
        existing_items = list(provider.models_json or [])
        if any(
            model.model_name in (item.get("model_id"), item.get("modelId"))
            for item in existing_items
        ):
            return
        existing_items.append(
            {
                "model_id": model.model_name,
                "display_name": model.name,
                "model_type": "chat",
            }
        )
        # Reassign (not mutate) so the attribute change is visible to the ORM.
        provider.models_json = existing_items

    def _build_provider_name(self, provider_type: str, base_url: str) -> str:
        """Build a display name such as ``"Openai Compatible - api.example.com"``.

        The host part is the text between ``//`` and the next ``/``. Any
        ``user:password@`` userinfo is dropped, so credentials embedded in the
        URL never appear in the provider name. Without a host, only the label
        is returned.
        """
        label = provider_type.replace("_", " ").title()
        authority = base_url.split("//")[-1].split("/")[0]
        host = authority.rpartition("@")[2]
        return f"{label} - {host}" if host else label
- class _ResolvedModelProvider:
- def __init__(
- self,
- *,
- provider_type: str,
- provider_base_url: str,
- provider_api_key: str | None) -> None:
- self.provider_type = provider_type
- self.provider_base_url = provider_base_url
- self.provider_api_key = provider_api_key
|