routes.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import json
  2. from datetime import datetime
  3. from typing import Annotated, TypeVar
  4. from core_domain import (
  5. ChatCompletionRequestContract,
  6. ChatCompletionResponseContract,
  7. EmbeddingRequestContract,
  8. EmbeddingResponseContract,
  9. ServiceHealth,
  10. )
  11. from core_shared import error_detail
  12. from fastapi import APIRouter, Depends, HTTPException
  13. from fastapi.responses import StreamingResponse
  14. from sqlalchemy import text
  15. from sqlalchemy.orm import Session
  16. from app.application.services import ModelGatewayApplicationService
  17. from app.bootstrap.settings import ModelGatewayServiceSettings
  18. from app.db.session import get_db
  19. from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
  20. from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
  21. from app.schemas.model import (
  22. ApiResponse,
  23. DeleteData,
  24. DiscoverModelsData,
  25. DiscoverModelsRequestDto,
  26. ModelCreateRequest,
  27. ModelCreateRequestDto,
  28. ModelDeleteRequestDto,
  29. ModelDto,
  30. ModelProviderCreateRequestDto,
  31. ModelProviderDeleteRequestDto,
  32. ModelProviderDto,
  33. ModelProviderTestData,
  34. ModelProviderTestRequestDto,
  35. ModelProviderUpdateRequestDto,
  36. ModelResponse,
  37. ModelStatusUpdateRequest,
  38. ModelTestData,
  39. ModelTestRequest,
  40. ModelTestRequestDto,
  41. ModelTestResponse,
  42. ModelUpdateRequest,
  43. ModelUpdateRequestDto,
  44. PageRequest,
  45. PageResult,
  46. )
# Router holding every model-gateway endpoint below; presumably mounted under a
# path prefix by the application factory (not visible here).
router = APIRouter()
# SQLAlchemy session injected through the `get_db` dependency.
DbSession = Annotated[Session, Depends(get_db)]
# Generic payload type for the `ok` response-envelope helper.
T = TypeVar("T")
  50. def get_model_gateway_settings() -> ModelGatewayServiceSettings:
  51. return ModelGatewayServiceSettings()
  52. def get_model_gateway_application_service(
  53. db: DbSession,
  54. settings: Annotated[
  55. ModelGatewayServiceSettings,
  56. Depends(get_model_gateway_settings),
  57. ]) -> ModelGatewayApplicationService:
  58. return ModelGatewayApplicationService(
  59. model_repository=ModelDefinitionRepository(db),
  60. provider_repository=ModelProviderDefinitionRepository(db),
  61. provider_client=ModelProviderClient(settings=settings),
  62. settings=settings)
# Injectable application-service parameter type: FastAPI resolves it through
# `get_model_gateway_application_service` (which in turn needs a DB session
# and the gateway settings).
ModelServiceDep = Annotated[
    ModelGatewayApplicationService,
    Depends(get_model_gateway_application_service),
]
  67. def ok(data: T) -> ApiResponse[T]:
  68. return ApiResponse[T](
  69. data=data,
  70. requestId="",
  71. serverTime=datetime.utcnow())
  72. def json_dump(payload: dict[str, object]) -> str:
  73. return json.dumps(payload, ensure_ascii=False, default=str)
  74. @router.get("/health", response_model=ServiceHealth)
  75. def health_check(
  76. db: DbSession,
  77. settings: Annotated[
  78. ModelGatewayServiceSettings,
  79. Depends(get_model_gateway_settings),
  80. ]) -> ServiceHealth:
  81. db.execute(text("SELECT 1"))
  82. provider_status = "configured" if settings.provider_base_url else "missing"
  83. return ServiceHealth(service="model-gateway-service", status="ok", database=provider_status)
  84. @router.post("", response_model=ModelResponse)
  85. def create_model(
  86. payload: ModelCreateRequest,
  87. service: ModelServiceDep,
  88. ) -> ModelResponse:
  89. try:
  90. entity = service.create_model(payload)
  91. except ValueError as exc:
  92. raise HTTPException(status_code=422, detail=error_detail("error.validation", message=str(exc))) from exc
  93. return ModelResponse.from_entity(entity)
  94. @router.get("", response_model=list[ModelResponse])
  95. def list_models(
  96. service: ModelServiceDep,
  97. ) -> list[ModelResponse]:
  98. return [ModelResponse.from_entity(item) for item in service.list_models()]
  99. @router.patch("/{model_id}", response_model=ModelResponse)
  100. def update_model(
  101. model_id: str,
  102. payload: ModelUpdateRequest,
  103. service: ModelServiceDep,
  104. ) -> ModelResponse:
  105. try:
  106. entity = service.update_model(model_id=model_id, payload=payload)
  107. except ValueError as exc:
  108. raise HTTPException(status_code=422, detail=error_detail("error.validation", message=str(exc))) from exc
  109. if entity is None:
  110. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=model_id))
  111. return ModelResponse.from_entity(entity)
  112. @router.patch("/{model_id}/status", response_model=ModelResponse)
  113. def update_model_status(
  114. model_id: str,
  115. payload: ModelStatusUpdateRequest,
  116. service: ModelServiceDep,
  117. ) -> ModelResponse:
  118. entity = service.update_model_status(model_id=model_id, payload=payload)
  119. if entity is None:
  120. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=model_id))
  121. return ModelResponse.from_entity(entity)
  122. @router.delete("/{model_id}", status_code=204)
  123. def delete_model(
  124. model_id: str,
  125. service: ModelServiceDep,
  126. ) -> None:
  127. if not service.delete_model(model_id):
  128. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=model_id))
  129. @router.post("/{model_id}/test", response_model=ModelTestResponse)
  130. def test_model(
  131. model_id: str,
  132. payload: ModelTestRequest,
  133. service: ModelServiceDep,
  134. ) -> ModelTestResponse:
  135. try:
  136. result = service.test_model(model_id=model_id, payload=payload)
  137. except ModelProviderClientError as exc:
  138. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc
  139. if result is None:
  140. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=model_id))
  141. return result
  142. @router.post("/chat-completions", response_model=ChatCompletionResponseContract)
  143. def create_chat_completion(
  144. payload: ChatCompletionRequestContract,
  145. service: ModelServiceDep) -> ChatCompletionResponseContract:
  146. try:
  147. return service.create_chat_completion(payload)
  148. except ModelProviderClientError as exc:
  149. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc
  150. @router.post("/chat-completions/stream")
  151. def stream_chat_completion(
  152. payload: ChatCompletionRequestContract,
  153. service: ModelServiceDep) -> StreamingResponse:
  154. def events():
  155. try:
  156. for delta in service.stream_chat_completion(payload):
  157. yield f"event: delta\ndata: {json_dump({'delta': delta})}\n\n"
  158. yield "event: done\ndata: {}\n\n"
  159. except ModelProviderClientError as exc:
  160. yield f"event: error\ndata: {json_dump({'message': str(exc)})}\n\n"
  161. return StreamingResponse(
  162. events(),
  163. media_type="text/event-stream",
  164. headers={
  165. "Cache-Control": "no-cache",
  166. "X-Accel-Buffering": "no",
  167. })
  168. @router.post("/embeddings", response_model=EmbeddingResponseContract)
  169. def create_embedding(
  170. payload: EmbeddingRequestContract,
  171. service: ModelServiceDep) -> EmbeddingResponseContract:
  172. try:
  173. return service.create_embedding(payload)
  174. except ModelProviderClientError as exc:
  175. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc
  176. @router.post("/list", response_model=ApiResponse[PageResult[ModelDto]])
  177. def list_models_contract(
  178. payload: PageRequest,
  179. service: ModelServiceDep) -> ApiResponse[PageResult[ModelDto]]:
  180. keyword = (payload.keyword or "").lower().strip()
  181. items = [
  182. item
  183. for item in service.list_models()
  184. if not keyword
  185. or keyword in item.name.lower()
  186. or keyword in item.model_name.lower()
  187. or keyword in item.provider_type.lower()
  188. ]
  189. page_items = items[payload.offset:payload.offset + payload.pageSize]
  190. return ok(
  191. PageResult[ModelDto].from_items(
  192. items=[ModelDto.from_entity(item) for item in page_items],
  193. total=len(items),
  194. page=payload.page,
  195. page_size=payload.pageSize))
  196. @router.post("/create", response_model=ApiResponse[ModelDto])
  197. def create_model_contract(
  198. payload: ModelCreateRequestDto,
  199. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  200. try:
  201. entity = service.create_model_from_contract(payload)
  202. except ValueError as exc:
  203. raise HTTPException(status_code=422, detail=error_detail("error.validation", message=str(exc))) from exc
  204. return ok(ModelDto.from_entity(entity))
  205. @router.post("/update", response_model=ApiResponse[ModelDto])
  206. def update_model_contract(
  207. payload: ModelUpdateRequestDto,
  208. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  209. try:
  210. entity = service.update_model_from_contract(payload)
  211. except ValueError as exc:
  212. raise HTTPException(status_code=422, detail=error_detail("error.validation", message=str(exc))) from exc
  213. if entity is None:
  214. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=payload.modelId))
  215. return ok(ModelDto.from_entity(entity))
  216. @router.post("/delete", response_model=ApiResponse[DeleteData])
  217. def delete_model_contract(
  218. payload: ModelDeleteRequestDto,
  219. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  220. deleted = service.delete_model_from_contract(payload)
  221. return ok(DeleteData(deleted=deleted, modelId=payload.modelId))
  222. @router.post("/test", response_model=ApiResponse[ModelTestData])
  223. def test_model_contract(
  224. payload: ModelTestRequestDto,
  225. service: ModelServiceDep) -> ApiResponse[ModelTestData]:
  226. try:
  227. result = service.test_model_from_contract(payload)
  228. except ModelProviderClientError as exc:
  229. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc
  230. if result is None:
  231. raise HTTPException(status_code=404, detail=error_detail("error.model.not_found", id=payload.modelId))
  232. return ok(result)
  233. @router.post("/providers/list", response_model=ApiResponse[PageResult[ModelProviderDto]])
  234. def list_model_providers_contract(
  235. payload: PageRequest,
  236. service: ModelServiceDep) -> ApiResponse[PageResult[ModelProviderDto]]:
  237. keyword = (payload.keyword or "").lower().strip()
  238. items = [
  239. item
  240. for item in service.list_providers()
  241. if not keyword
  242. or keyword in item.name.lower()
  243. or keyword in item.provider_type.lower()
  244. or keyword in item.base_url.lower()
  245. ]
  246. page_items = items[payload.offset:payload.offset + payload.pageSize]
  247. return ok(
  248. PageResult[ModelProviderDto].from_items(
  249. items=[ModelProviderDto.from_entity(item) for item in page_items],
  250. total=len(items),
  251. page=payload.page,
  252. page_size=payload.pageSize))
  253. @router.post("/providers/create", response_model=ApiResponse[ModelProviderDto])
  254. def create_model_provider_contract(
  255. payload: ModelProviderCreateRequestDto,
  256. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  257. entity = service.create_provider(payload)
  258. return ok(ModelProviderDto.from_entity(entity))
  259. @router.post("/providers/update", response_model=ApiResponse[ModelProviderDto])
  260. def update_model_provider_contract(
  261. payload: ModelProviderUpdateRequestDto,
  262. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  263. entity = service.update_provider(payload)
  264. if entity is None:
  265. raise HTTPException(status_code=404, detail=error_detail("error.provider.not_found", id=payload.providerId))
  266. return ok(ModelProviderDto.from_entity(entity))
  267. @router.post("/providers/delete", response_model=ApiResponse[DeleteData])
  268. def delete_model_provider_contract(
  269. payload: ModelProviderDeleteRequestDto,
  270. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  271. deleted = service.delete_provider(payload)
  272. return ok(DeleteData(deleted=deleted, providerId=payload.providerId))
  273. @router.post("/providers/test", response_model=ApiResponse[ModelProviderTestData])
  274. def test_model_provider_contract(
  275. payload: ModelProviderTestRequestDto,
  276. service: ModelServiceDep) -> ApiResponse[ModelProviderTestData]:
  277. try:
  278. result = service.test_provider(payload)
  279. except ModelProviderClientError as exc:
  280. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc
  281. if result is None:
  282. raise HTTPException(status_code=404, detail=error_detail("error.provider.not_found", id=payload.providerId))
  283. return ok(result)
  284. @router.post("/providers/discover", response_model=ApiResponse[DiscoverModelsData])
  285. def discover_models_contract(
  286. payload: DiscoverModelsRequestDto,
  287. service: ModelServiceDep) -> ApiResponse[DiscoverModelsData]:
  288. try:
  289. return ok(service.discover_models(payload))
  290. except ModelProviderClientError as exc:
  291. raise HTTPException(status_code=502, detail=error_detail("error.downstream.request_failed", service="model-gateway", error=str(exc))) from exc