services.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from core_domain import ChatCompletionRequestContract, ChatCompletionResponseContract
  2. from app.bootstrap.settings import ModelGatewayServiceSettings
  3. from app.db.models import ModelDefinition
  4. from app.domain.repositories import ModelDefinitionRepository
  5. from app.infrastructure.provider import ModelProviderClient
  6. from app.schemas.model import (
  7. ModelCreateRequest,
  8. ModelStatusUpdateRequest,
  9. ModelTestRequest,
  10. ModelTestResponse,
  11. ModelUpdateRequest,
  12. )
  13. class ModelGatewayApplicationService:
  14. def __init__(
  15. self,
  16. *,
  17. model_repository: ModelDefinitionRepository,
  18. provider_client: ModelProviderClient,
  19. settings: ModelGatewayServiceSettings) -> None:
  20. self.model_repository = model_repository
  21. self.provider_client = provider_client
  22. self.settings = settings
  23. def create_model(self, payload: ModelCreateRequest) -> ModelDefinition:
  24. if self.model_repository.get_by_code(payload.code) is not None:
  25. raise ValueError(f"model code already exists: {payload.code}")
  26. return self.model_repository.create(
  27. ModelDefinition(
  28. code=payload.code,
  29. name=payload.name,
  30. provider_type=payload.provider_type,
  31. provider_base_url=str(payload.provider_base_url),
  32. provider_api_key=payload.provider_api_key,
  33. model_name=payload.model_name,
  34. status=payload.status,
  35. description=payload.description,
  36. capabilities_json=payload.capabilities_json,
  37. context_window=payload.context_window,
  38. max_output_tokens=payload.max_output_tokens,
  39. default_temperature=payload.default_temperature,
  40. timeout_seconds=payload.timeout_seconds,
  41. metadata_json=payload.metadata_json,
  42. )
  43. )
  44. def list_models(self) -> list[ModelDefinition]:
  45. return self.model_repository.list_all()
  46. def update_model(self, model_id: str, payload: ModelUpdateRequest) -> ModelDefinition | None:
  47. entity = self.model_repository.get_by_id(model_id)
  48. if entity is None:
  49. return None
  50. updates = payload.model_dump(exclude_unset=True)
  51. if "code" in updates and updates["code"] != entity.code:
  52. existing = self.model_repository.get_by_code(str(updates["code"]))
  53. if existing is not None and existing.id != entity.id:
  54. raise ValueError(f"model code already exists: {updates['code']}")
  55. for key, value in updates.items():
  56. if key == "provider_base_url" and value is not None:
  57. value = str(value)
  58. setattr(entity, key, value)
  59. return self.model_repository.update(entity)
  60. def update_model_status(
  61. self,
  62. model_id: str,
  63. payload: ModelStatusUpdateRequest,
  64. ) -> ModelDefinition | None:
  65. entity = self.model_repository.get_by_id(model_id)
  66. if entity is None:
  67. return None
  68. entity.status = payload.status
  69. return self.model_repository.update(entity)
  70. def delete_model(self, model_id: str) -> bool:
  71. entity = self.model_repository.get_by_id(model_id)
  72. if entity is None:
  73. return False
  74. self.model_repository.delete(entity)
  75. return True
  76. def create_chat_completion(
  77. self,
  78. payload: ChatCompletionRequestContract) -> ChatCompletionResponseContract:
  79. configured_model = None
  80. if payload.model:
  81. configured_model = self.model_repository.get_active_for_request(payload.model)
  82. if configured_model is not None:
  83. resolved_payload = payload.model_copy(
  84. update={
  85. "model": configured_model.model_name,
  86. "temperature": payload.temperature
  87. if payload.temperature is not None
  88. else configured_model.default_temperature,
  89. "max_tokens": payload.max_tokens or configured_model.max_output_tokens,
  90. }
  91. )
  92. return self.provider_client.create_chat_completion(
  93. resolved_payload,
  94. provider_base_url=configured_model.provider_base_url,
  95. provider_api_key=configured_model.provider_api_key,
  96. timeout_seconds=configured_model.timeout_seconds,
  97. )
  98. resolved_payload = payload.model_copy(
  99. update={"model": payload.model or self.settings.default_model}
  100. )
  101. return self.provider_client.create_chat_completion(resolved_payload)
  102. def test_model(self, model_id: str, payload: ModelTestRequest) -> ModelTestResponse | None:
  103. entity = self.model_repository.get_by_id(model_id)
  104. if entity is None:
  105. return None
  106. messages = []
  107. if payload.system_prompt:
  108. messages.append({"role": "system", "content": payload.system_prompt})
  109. messages.append({"role": "user", "content": payload.prompt})
  110. response = self.provider_client.create_chat_completion(
  111. ChatCompletionRequestContract(
  112. model=entity.model_name,
  113. messages=messages,
  114. temperature=payload.temperature
  115. if payload.temperature is not None
  116. else entity.default_temperature,
  117. max_tokens=payload.max_tokens or entity.max_output_tokens,
  118. ),
  119. provider_base_url=entity.provider_base_url,
  120. provider_api_key=entity.provider_api_key,
  121. timeout_seconds=entity.timeout_seconds,
  122. )
  123. from app.schemas.model import ModelResponse
  124. return ModelTestResponse(model=ModelResponse.from_entity(entity), response=response)