| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
- from app.bootstrap.settings import ModelGatewayServiceSettings
- from app.db.models import ModelDefinition
- from app.domain.repositories import ModelDefinitionRepository
- from app.infrastructure.provider import ModelProviderClient
- from app.schemas.model import (
- ModelCreateRequest,
- ModelStatusUpdateRequest,
- ModelTestRequest,
- ModelTestResponse,
- ModelUpdateRequest,
- )
class ModelGatewayApplicationService:
    """Application service for model-definition CRUD and chat-completion routing.

    Model definitions are read and written through ``model_repository``.
    Chat completions go through ``provider_client``. When a stored, active
    definition matches the request, its connection settings are used;
    otherwise the request is sent without per-model connection overrides,
    with ``settings.default_model`` filled in if no model was named.
    """

    def __init__(
        self,
        *,
        model_repository: ModelDefinitionRepository,
        provider_client: ModelProviderClient,
        settings: ModelGatewayServiceSettings,
    ) -> None:
        """Store the collaborators used by every operation.

        Args:
            model_repository: Persistence port for ``ModelDefinition`` entities.
            provider_client: Client that performs the upstream chat-completion call.
            settings: Service settings; this class reads only ``default_model``.
        """
        self.model_repository = model_repository
        self.provider_client = provider_client
        self.settings = settings

    def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
        """Create and persist a new model definition.

        Args:
            payload: Creation request whose fields are copied onto a new
                ``ModelDefinition``; ``provider_base_url`` is stored as ``str``.

        Returns:
            The entity returned by ``model_repository.create``.

        Raises:
            ValueError: If a model with ``payload.code`` already exists.
        """
        # NOTE(review): the existence check and the insert are separate
        # repository calls, so two concurrent requests can both pass the check;
        # presumably a unique constraint on ``code`` backs this up — confirm.
        if self.model_repository.get_by_code(payload.code) is not None:
            raise ValueError(f"model code already exists: {payload.code}")
        return self.model_repository.create(
            ModelDefinition(
                code=payload.code,
                name=payload.name,
                provider_type=payload.provider_type,
                provider_base_url=str(payload.provider_base_url),
                provider_api_key=payload.provider_api_key,
                model_name=payload.model_name,
                status=payload.status,
                description=payload.description,
                capabilities_json=payload.capabilities_json,
                context_window=payload.context_window,
                max_output_tokens=payload.max_output_tokens,
                default_temperature=payload.default_temperature,
                timeout_seconds=payload.timeout_seconds,
                metadata_json=payload.metadata_json,
            )
        )

    def list_models(self) -> list[ModelDefinition]:
        """Return every stored model definition (delegates to ``list_all``)."""
        return self.model_repository.list_all()

    def update_model(
        self,
        model_id: str,
        payload: ModelUpdateRequest,
    ) -> ModelDefinition | None:
        """Apply a partial update to an existing model definition.

        Only fields explicitly set on ``payload`` are written
        (``model_dump(exclude_unset=True)``): omitted fields keep their current
        values, while fields explicitly sent as ``None`` are written as ``None``.

        Returns:
            The updated entity, or ``None`` if ``model_id`` does not exist.

        Raises:
            ValueError: If ``code`` is changed to one already used by another model.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        updates = payload.model_dump(exclude_unset=True)
        if "code" in updates and updates["code"] != entity.code:
            existing = self.model_repository.get_by_code(str(updates["code"]))
            if existing is not None and existing.id != entity.id:
                raise ValueError(f"model code already exists: {updates['code']}")
        for key, value in updates.items():
            # Store the URL as a plain string, matching create_model.
            if key == "provider_base_url" and value is not None:
                value = str(value)
            setattr(entity, key, value)
        return self.model_repository.update(entity)

    def update_model_status(
        self,
        model_id: str,
        payload: ModelStatusUpdateRequest,
    ) -> ModelDefinition | None:
        """Set ``status`` on a model definition.

        Returns:
            The updated entity, or ``None`` if ``model_id`` does not exist.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        entity.status = payload.status
        return self.model_repository.update(entity)

    def delete_model(self, model_id: str) -> bool:
        """Delete a model definition.

        Returns:
            ``True`` if the model existed and was deleted, ``False`` otherwise.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return False
        self.model_repository.delete(entity)
        return True

    def create_chat_completion(
        self,
        payload: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """Route a chat-completion request to the appropriate provider.

        If ``payload.model`` is set and ``get_active_for_request`` resolves it
        to a stored definition, the request is rewritten to that definition's
        ``model_name`` and its default sampling values (see ``_resolve_sampling``),
        then sent with the definition's connection settings. Otherwise, the
        request is forwarded without per-model connection overrides, using
        ``payload.model`` if given or ``settings.default_model`` if not.
        """
        configured_model = (
            self.model_repository.get_active_for_request(payload.model)
            if payload.model
            else None
        )
        if configured_model is None:
            fallback_payload = payload.model_copy(
                update={"model": payload.model or self.settings.default_model}
            )
            return self.provider_client.create_chat_completion(fallback_payload)

        temperature, max_tokens = self._resolve_sampling(
            configured_model, payload.temperature, payload.max_tokens
        )
        resolved_payload = payload.model_copy(
            update={
                "model": configured_model.model_name,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
        )
        return self._call_provider(configured_model, resolved_payload)

    def test_model(
        self,
        model_id: str,
        payload: ModelTestRequest,
    ) -> ModelTestResponse | None:
        """Send a one-off prompt to a stored model definition.

        Builds a message list from the optional ``system_prompt`` and the
        required ``prompt``, applies the definition's sampling defaults
        (see ``_resolve_sampling``), and calls the provider with the
        definition's connection settings. The definition's ``status`` is not
        checked here.

        Returns:
            The model description plus the provider response, or ``None`` if
            ``model_id`` does not exist.
        """
        entity = self.model_repository.get_by_id(model_id)
        if entity is None:
            return None
        messages: list[dict[str, str]] = []
        if payload.system_prompt:
            messages.append({"role": "system", "content": payload.system_prompt})
        messages.append({"role": "user", "content": payload.prompt})
        temperature, max_tokens = self._resolve_sampling(
            entity, payload.temperature, payload.max_tokens
        )
        response = self._call_provider(
            entity,
            ChatCompletionRequestContract(
                model=entity.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            ),
        )
        # Deferred import kept as in the original code.
        # NOTE(review): app.schemas.model is already imported at module level,
        # so this could likely be hoisted — confirm no import cycle first.
        from app.schemas.model import ModelResponse
        return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)

    @staticmethod
    def _resolve_sampling(entity: ModelDefinition, temperature, max_tokens):
        """Combine request-level sampling overrides with a definition's defaults.

        An explicit ``temperature`` wins whenever it is not ``None`` (so ``0.0``
        is honoured). ``max_tokens`` falls back to ``entity.max_output_tokens``
        whenever it is falsy, so both ``None`` and ``0`` use the default.

        Returns:
            A ``(temperature, max_tokens)`` tuple.
        """
        resolved_temperature = (
            temperature if temperature is not None else entity.default_temperature
        )
        resolved_max_tokens = max_tokens or entity.max_output_tokens
        return resolved_temperature, resolved_max_tokens

    def _call_provider(
        self,
        entity: ModelDefinition,
        request: ChatCompletionRequestContract,
    ) -> ChatCompletionResponseContract:
        """Send ``request`` using the connection settings stored on ``entity``."""
        return self.provider_client.create_chat_completion(
            request,
            provider_base_url=entity.provider_base_url,
            provider_api_key=entity.provider_api_key,
            timeout_seconds=entity.timeout_seconds,
        )
|