routes.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from datetime import datetime
  2. from typing import Annotated, TypeVar
  3. from core_domain import (
  4. ChatCompletionRequestContract,
  5. ChatCompletionResponseContract,
  6. ServiceHealth,
  7. )
  8. from fastapi import APIRouter, Depends, HTTPException
  9. from sqlalchemy import text
  10. from sqlalchemy.orm import Session
  11. from app.application.services import ModelGatewayApplicationService
  12. from app.bootstrap.settings import ModelGatewayServiceSettings
  13. from app.db.session import get_db
  14. from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
  15. from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
  16. from app.schemas.model import (
  17. ApiResponse,
  18. DeleteData,
  19. DiscoverModelsData,
  20. DiscoverModelsRequestDto,
  21. ModelCreateRequest,
  22. ModelCreateRequestDto,
  23. ModelDeleteRequestDto,
  24. ModelDto,
  25. ModelProviderCreateRequestDto,
  26. ModelProviderDeleteRequestDto,
  27. ModelProviderDto,
  28. ModelProviderTestData,
  29. ModelProviderTestRequestDto,
  30. ModelProviderUpdateRequestDto,
  31. ModelResponse,
  32. ModelStatusUpdateRequest,
  33. ModelTestData,
  34. ModelTestRequest,
  35. ModelTestRequestDto,
  36. ModelTestResponse,
  37. ModelUpdateRequest,
  38. ModelUpdateRequestDto,
  39. PageRequest,
  40. PageResult,
  41. )
  42. router = APIRouter()
  43. DbSession = Annotated[Session, Depends(get_db)]
  44. T = TypeVar("T")
  45. def get_model_gateway_settings() -> ModelGatewayServiceSettings:
  46. return ModelGatewayServiceSettings()
  47. def get_model_gateway_application_service(
  48. db: DbSession,
  49. settings: Annotated[
  50. ModelGatewayServiceSettings,
  51. Depends(get_model_gateway_settings),
  52. ]) -> ModelGatewayApplicationService:
  53. return ModelGatewayApplicationService(
  54. model_repository=ModelDefinitionRepository(db),
  55. provider_repository=ModelProviderDefinitionRepository(db),
  56. provider_client=ModelProviderClient(settings=settings),
  57. settings=settings)
  58. ModelServiceDep = Annotated[
  59. ModelGatewayApplicationService,
  60. Depends(get_model_gateway_application_service),
  61. ]
  62. def ok(data: T) -> ApiResponse[T]:
  63. return ApiResponse[T](
  64. data=data,
  65. requestId="",
  66. serverTime=datetime.utcnow())
  67. @router.get("/health", response_model=ServiceHealth)
  68. def health_check(
  69. db: DbSession,
  70. settings: Annotated[
  71. ModelGatewayServiceSettings,
  72. Depends(get_model_gateway_settings),
  73. ]) -> ServiceHealth:
  74. db.execute(text("SELECT 1"))
  75. provider_status = "configured" if settings.provider_base_url else "missing"
  76. return ServiceHealth(service="model-gateway-service", status="ok", database=provider_status)
  77. @router.post("", response_model=ModelResponse)
  78. def create_model(
  79. payload: ModelCreateRequest,
  80. service: ModelServiceDep,
  81. ) -> ModelResponse:
  82. try:
  83. entity = service.create_model(payload)
  84. except ValueError as exc:
  85. raise HTTPException(status_code=422, detail=str(exc)) from exc
  86. return ModelResponse.from_entity(entity)
  87. @router.get("", response_model=list[ModelResponse])
  88. def list_models(
  89. service: ModelServiceDep,
  90. ) -> list[ModelResponse]:
  91. return [ModelResponse.from_entity(item) for item in service.list_models()]
  92. @router.patch("/{model_id}", response_model=ModelResponse)
  93. def update_model(
  94. model_id: str,
  95. payload: ModelUpdateRequest,
  96. service: ModelServiceDep,
  97. ) -> ModelResponse:
  98. try:
  99. entity = service.update_model(model_id=model_id, payload=payload)
  100. except ValueError as exc:
  101. raise HTTPException(status_code=422, detail=str(exc)) from exc
  102. if entity is None:
  103. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  104. return ModelResponse.from_entity(entity)
  105. @router.patch("/{model_id}/status", response_model=ModelResponse)
  106. def update_model_status(
  107. model_id: str,
  108. payload: ModelStatusUpdateRequest,
  109. service: ModelServiceDep,
  110. ) -> ModelResponse:
  111. entity = service.update_model_status(model_id=model_id, payload=payload)
  112. if entity is None:
  113. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  114. return ModelResponse.from_entity(entity)
  115. @router.delete("/{model_id}", status_code=204)
  116. def delete_model(
  117. model_id: str,
  118. service: ModelServiceDep,
  119. ) -> None:
  120. if not service.delete_model(model_id):
  121. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  122. @router.post("/{model_id}/test", response_model=ModelTestResponse)
  123. def test_model(
  124. model_id: str,
  125. payload: ModelTestRequest,
  126. service: ModelServiceDep,
  127. ) -> ModelTestResponse:
  128. try:
  129. result = service.test_model(model_id=model_id, payload=payload)
  130. except ModelProviderClientError as exc:
  131. raise HTTPException(status_code=502, detail=str(exc)) from exc
  132. if result is None:
  133. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  134. return result
  135. @router.post("/chat-completions", response_model=ChatCompletionResponseContract)
  136. def create_chat_completion(
  137. payload: ChatCompletionRequestContract,
  138. service: ModelServiceDep) -> ChatCompletionResponseContract:
  139. try:
  140. return service.create_chat_completion(payload)
  141. except ModelProviderClientError as exc:
  142. raise HTTPException(status_code=502, detail=str(exc)) from exc
  143. @router.post("/list", response_model=ApiResponse[PageResult[ModelDto]])
  144. def list_models_contract(
  145. payload: PageRequest,
  146. service: ModelServiceDep) -> ApiResponse[PageResult[ModelDto]]:
  147. keyword = (payload.keyword or "").lower().strip()
  148. items = [
  149. item
  150. for item in service.list_models()
  151. if not keyword
  152. or keyword in item.name.lower()
  153. or keyword in item.model_name.lower()
  154. or keyword in item.provider_type.lower()
  155. ]
  156. page_items = items[payload.offset:payload.offset + payload.pageSize]
  157. return ok(
  158. PageResult[ModelDto].from_items(
  159. items=[ModelDto.from_entity(item) for item in page_items],
  160. total=len(items),
  161. page=payload.page,
  162. page_size=payload.pageSize))
  163. @router.post("/create", response_model=ApiResponse[ModelDto])
  164. def create_model_contract(
  165. payload: ModelCreateRequestDto,
  166. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  167. try:
  168. entity = service.create_model_from_contract(payload)
  169. except ValueError as exc:
  170. raise HTTPException(status_code=422, detail=str(exc)) from exc
  171. return ok(ModelDto.from_entity(entity))
  172. @router.post("/update", response_model=ApiResponse[ModelDto])
  173. def update_model_contract(
  174. payload: ModelUpdateRequestDto,
  175. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  176. try:
  177. entity = service.update_model_from_contract(payload)
  178. except ValueError as exc:
  179. raise HTTPException(status_code=422, detail=str(exc)) from exc
  180. if entity is None:
  181. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  182. return ok(ModelDto.from_entity(entity))
  183. @router.post("/delete", response_model=ApiResponse[DeleteData])
  184. def delete_model_contract(
  185. payload: ModelDeleteRequestDto,
  186. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  187. deleted = service.delete_model_from_contract(payload)
  188. return ok(DeleteData(deleted=deleted, modelId=payload.modelId))
  189. @router.post("/test", response_model=ApiResponse[ModelTestData])
  190. def test_model_contract(
  191. payload: ModelTestRequestDto,
  192. service: ModelServiceDep) -> ApiResponse[ModelTestData]:
  193. try:
  194. result = service.test_model_from_contract(payload)
  195. except ModelProviderClientError as exc:
  196. raise HTTPException(status_code=502, detail=str(exc)) from exc
  197. if result is None:
  198. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  199. return ok(result)
  200. @router.post("/providers/list", response_model=ApiResponse[PageResult[ModelProviderDto]])
  201. def list_model_providers_contract(
  202. payload: PageRequest,
  203. service: ModelServiceDep) -> ApiResponse[PageResult[ModelProviderDto]]:
  204. keyword = (payload.keyword or "").lower().strip()
  205. items = [
  206. item
  207. for item in service.list_providers()
  208. if not keyword
  209. or keyword in item.name.lower()
  210. or keyword in item.provider_type.lower()
  211. or keyword in item.base_url.lower()
  212. ]
  213. page_items = items[payload.offset:payload.offset + payload.pageSize]
  214. return ok(
  215. PageResult[ModelProviderDto].from_items(
  216. items=[ModelProviderDto.from_entity(item) for item in page_items],
  217. total=len(items),
  218. page=payload.page,
  219. page_size=payload.pageSize))
  220. @router.post("/providers/create", response_model=ApiResponse[ModelProviderDto])
  221. def create_model_provider_contract(
  222. payload: ModelProviderCreateRequestDto,
  223. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  224. entity = service.create_provider(payload)
  225. return ok(ModelProviderDto.from_entity(entity))
  226. @router.post("/providers/update", response_model=ApiResponse[ModelProviderDto])
  227. def update_model_provider_contract(
  228. payload: ModelProviderUpdateRequestDto,
  229. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  230. entity = service.update_provider(payload)
  231. if entity is None:
  232. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  233. return ok(ModelProviderDto.from_entity(entity))
  234. @router.post("/providers/delete", response_model=ApiResponse[DeleteData])
  235. def delete_model_provider_contract(
  236. payload: ModelProviderDeleteRequestDto,
  237. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  238. deleted = service.delete_provider(payload)
  239. return ok(DeleteData(deleted=deleted, providerId=payload.providerId))
  240. @router.post("/providers/test", response_model=ApiResponse[ModelProviderTestData])
  241. def test_model_provider_contract(
  242. payload: ModelProviderTestRequestDto,
  243. service: ModelServiceDep) -> ApiResponse[ModelProviderTestData]:
  244. try:
  245. result = service.test_provider(payload)
  246. except ModelProviderClientError as exc:
  247. raise HTTPException(status_code=502, detail=str(exc)) from exc
  248. if result is None:
  249. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  250. return ok(result)
  251. @router.post("/providers/discover", response_model=ApiResponse[DiscoverModelsData])
  252. def discover_models_contract(
  253. payload: DiscoverModelsRequestDto,
  254. service: ModelServiceDep) -> ApiResponse[DiscoverModelsData]:
  255. try:
  256. return ok(service.discover_models(payload))
  257. except ModelProviderClientError as exc:
  258. raise HTTPException(status_code=502, detail=str(exc)) from exc