routes.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. import json
  2. from datetime import datetime
  3. from typing import Annotated, TypeVar
  4. from core_domain import (
  5. ChatCompletionRequestContract,
  6. ChatCompletionResponseContract,
  7. EmbeddingRequestContract,
  8. EmbeddingResponseContract,
  9. ServiceHealth,
  10. )
  11. from fastapi import APIRouter, Depends, HTTPException
  12. from fastapi.responses import StreamingResponse
  13. from sqlalchemy import text
  14. from sqlalchemy.orm import Session
  15. from app.application.services import ModelGatewayApplicationService
  16. from app.bootstrap.settings import ModelGatewayServiceSettings
  17. from app.db.session import get_db
  18. from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
  19. from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
  20. from app.schemas.model import (
  21. ApiResponse,
  22. DeleteData,
  23. DiscoverModelsData,
  24. DiscoverModelsRequestDto,
  25. ModelCreateRequest,
  26. ModelCreateRequestDto,
  27. ModelDeleteRequestDto,
  28. ModelDto,
  29. ModelProviderCreateRequestDto,
  30. ModelProviderDeleteRequestDto,
  31. ModelProviderDto,
  32. ModelProviderTestData,
  33. ModelProviderTestRequestDto,
  34. ModelProviderUpdateRequestDto,
  35. ModelResponse,
  36. ModelStatusUpdateRequest,
  37. ModelTestData,
  38. ModelTestRequest,
  39. ModelTestRequestDto,
  40. ModelTestResponse,
  41. ModelUpdateRequest,
  42. ModelUpdateRequestDto,
  43. PageRequest,
  44. PageResult,
  45. )
# Router on which every endpoint declared below in this file is registered.
router = APIRouter()
# Dependency alias: a SQLAlchemy Session resolved through get_db.
DbSession = Annotated[Session, Depends(get_db)]
# Type variable for the payload carried by ApiResponse (used by ok()).
T = TypeVar("T")
  49. def get_model_gateway_settings() -> ModelGatewayServiceSettings:
  50. return ModelGatewayServiceSettings()
  51. def get_model_gateway_application_service(
  52. db: DbSession,
  53. settings: Annotated[
  54. ModelGatewayServiceSettings,
  55. Depends(get_model_gateway_settings),
  56. ]) -> ModelGatewayApplicationService:
  57. return ModelGatewayApplicationService(
  58. model_repository=ModelDefinitionRepository(db),
  59. provider_repository=ModelProviderDefinitionRepository(db),
  60. provider_client=ModelProviderClient(settings=settings),
  61. settings=settings)
# Dependency alias: handlers annotate a parameter with this to receive the
# service built by get_model_gateway_application_service.
ModelServiceDep = Annotated[
    ModelGatewayApplicationService,
    Depends(get_model_gateway_application_service),
]
  66. def ok(data: T) -> ApiResponse[T]:
  67. return ApiResponse[T](
  68. data=data,
  69. requestId="",
  70. serverTime=datetime.utcnow())
  71. def json_dump(payload: dict[str, object]) -> str:
  72. return json.dumps(payload, ensure_ascii=False, default=str)
  73. @router.get("/health", response_model=ServiceHealth)
  74. def health_check(
  75. db: DbSession,
  76. settings: Annotated[
  77. ModelGatewayServiceSettings,
  78. Depends(get_model_gateway_settings),
  79. ]) -> ServiceHealth:
  80. db.execute(text("SELECT 1"))
  81. provider_status = "configured" if settings.provider_base_url else "missing"
  82. return ServiceHealth(service="model-gateway-service", status="ok", database=provider_status)
  83. @router.post("", response_model=ModelResponse)
  84. def create_model(
  85. payload: ModelCreateRequest,
  86. service: ModelServiceDep,
  87. ) -> ModelResponse:
  88. try:
  89. entity = service.create_model(payload)
  90. except ValueError as exc:
  91. raise HTTPException(status_code=422, detail=str(exc)) from exc
  92. return ModelResponse.from_entity(entity)
  93. @router.get("", response_model=list[ModelResponse])
  94. def list_models(
  95. service: ModelServiceDep,
  96. ) -> list[ModelResponse]:
  97. return [ModelResponse.from_entity(item) for item in service.list_models()]
  98. @router.patch("/{model_id}", response_model=ModelResponse)
  99. def update_model(
  100. model_id: str,
  101. payload: ModelUpdateRequest,
  102. service: ModelServiceDep,
  103. ) -> ModelResponse:
  104. try:
  105. entity = service.update_model(model_id=model_id, payload=payload)
  106. except ValueError as exc:
  107. raise HTTPException(status_code=422, detail=str(exc)) from exc
  108. if entity is None:
  109. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  110. return ModelResponse.from_entity(entity)
  111. @router.patch("/{model_id}/status", response_model=ModelResponse)
  112. def update_model_status(
  113. model_id: str,
  114. payload: ModelStatusUpdateRequest,
  115. service: ModelServiceDep,
  116. ) -> ModelResponse:
  117. entity = service.update_model_status(model_id=model_id, payload=payload)
  118. if entity is None:
  119. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  120. return ModelResponse.from_entity(entity)
  121. @router.delete("/{model_id}", status_code=204)
  122. def delete_model(
  123. model_id: str,
  124. service: ModelServiceDep,
  125. ) -> None:
  126. if not service.delete_model(model_id):
  127. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  128. @router.post("/{model_id}/test", response_model=ModelTestResponse)
  129. def test_model(
  130. model_id: str,
  131. payload: ModelTestRequest,
  132. service: ModelServiceDep,
  133. ) -> ModelTestResponse:
  134. try:
  135. result = service.test_model(model_id=model_id, payload=payload)
  136. except ModelProviderClientError as exc:
  137. raise HTTPException(status_code=502, detail=str(exc)) from exc
  138. if result is None:
  139. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  140. return result
  141. @router.post("/chat-completions", response_model=ChatCompletionResponseContract)
  142. def create_chat_completion(
  143. payload: ChatCompletionRequestContract,
  144. service: ModelServiceDep) -> ChatCompletionResponseContract:
  145. try:
  146. return service.create_chat_completion(payload)
  147. except ModelProviderClientError as exc:
  148. raise HTTPException(status_code=502, detail=str(exc)) from exc
  149. @router.post("/chat-completions/stream")
  150. def stream_chat_completion(
  151. payload: ChatCompletionRequestContract,
  152. service: ModelServiceDep) -> StreamingResponse:
  153. def events():
  154. try:
  155. for delta in service.stream_chat_completion(payload):
  156. yield f"event: delta\ndata: {json_dump({'delta': delta})}\n\n"
  157. yield "event: done\ndata: {}\n\n"
  158. except ModelProviderClientError as exc:
  159. yield f"event: error\ndata: {json_dump({'message': str(exc)})}\n\n"
  160. return StreamingResponse(
  161. events(),
  162. media_type="text/event-stream",
  163. headers={
  164. "Cache-Control": "no-cache",
  165. "X-Accel-Buffering": "no",
  166. })
  167. @router.post("/embeddings", response_model=EmbeddingResponseContract)
  168. def create_embedding(
  169. payload: EmbeddingRequestContract,
  170. service: ModelServiceDep) -> EmbeddingResponseContract:
  171. try:
  172. return service.create_embedding(payload)
  173. except ModelProviderClientError as exc:
  174. raise HTTPException(status_code=502, detail=str(exc)) from exc
  175. @router.post("/list", response_model=ApiResponse[PageResult[ModelDto]])
  176. def list_models_contract(
  177. payload: PageRequest,
  178. service: ModelServiceDep) -> ApiResponse[PageResult[ModelDto]]:
  179. keyword = (payload.keyword or "").lower().strip()
  180. items = [
  181. item
  182. for item in service.list_models()
  183. if not keyword
  184. or keyword in item.name.lower()
  185. or keyword in item.model_name.lower()
  186. or keyword in item.provider_type.lower()
  187. ]
  188. page_items = items[payload.offset:payload.offset + payload.pageSize]
  189. return ok(
  190. PageResult[ModelDto].from_items(
  191. items=[ModelDto.from_entity(item) for item in page_items],
  192. total=len(items),
  193. page=payload.page,
  194. page_size=payload.pageSize))
  195. @router.post("/create", response_model=ApiResponse[ModelDto])
  196. def create_model_contract(
  197. payload: ModelCreateRequestDto,
  198. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  199. try:
  200. entity = service.create_model_from_contract(payload)
  201. except ValueError as exc:
  202. raise HTTPException(status_code=422, detail=str(exc)) from exc
  203. return ok(ModelDto.from_entity(entity))
  204. @router.post("/update", response_model=ApiResponse[ModelDto])
  205. def update_model_contract(
  206. payload: ModelUpdateRequestDto,
  207. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  208. try:
  209. entity = service.update_model_from_contract(payload)
  210. except ValueError as exc:
  211. raise HTTPException(status_code=422, detail=str(exc)) from exc
  212. if entity is None:
  213. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  214. return ok(ModelDto.from_entity(entity))
  215. @router.post("/delete", response_model=ApiResponse[DeleteData])
  216. def delete_model_contract(
  217. payload: ModelDeleteRequestDto,
  218. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  219. deleted = service.delete_model_from_contract(payload)
  220. return ok(DeleteData(deleted=deleted, modelId=payload.modelId))
  221. @router.post("/test", response_model=ApiResponse[ModelTestData])
  222. def test_model_contract(
  223. payload: ModelTestRequestDto,
  224. service: ModelServiceDep) -> ApiResponse[ModelTestData]:
  225. try:
  226. result = service.test_model_from_contract(payload)
  227. except ModelProviderClientError as exc:
  228. raise HTTPException(status_code=502, detail=str(exc)) from exc
  229. if result is None:
  230. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  231. return ok(result)
  232. @router.post("/providers/list", response_model=ApiResponse[PageResult[ModelProviderDto]])
  233. def list_model_providers_contract(
  234. payload: PageRequest,
  235. service: ModelServiceDep) -> ApiResponse[PageResult[ModelProviderDto]]:
  236. keyword = (payload.keyword or "").lower().strip()
  237. items = [
  238. item
  239. for item in service.list_providers()
  240. if not keyword
  241. or keyword in item.name.lower()
  242. or keyword in item.provider_type.lower()
  243. or keyword in item.base_url.lower()
  244. ]
  245. page_items = items[payload.offset:payload.offset + payload.pageSize]
  246. return ok(
  247. PageResult[ModelProviderDto].from_items(
  248. items=[ModelProviderDto.from_entity(item) for item in page_items],
  249. total=len(items),
  250. page=payload.page,
  251. page_size=payload.pageSize))
  252. @router.post("/providers/create", response_model=ApiResponse[ModelProviderDto])
  253. def create_model_provider_contract(
  254. payload: ModelProviderCreateRequestDto,
  255. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  256. entity = service.create_provider(payload)
  257. return ok(ModelProviderDto.from_entity(entity))
  258. @router.post("/providers/update", response_model=ApiResponse[ModelProviderDto])
  259. def update_model_provider_contract(
  260. payload: ModelProviderUpdateRequestDto,
  261. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  262. entity = service.update_provider(payload)
  263. if entity is None:
  264. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  265. return ok(ModelProviderDto.from_entity(entity))
  266. @router.post("/providers/delete", response_model=ApiResponse[DeleteData])
  267. def delete_model_provider_contract(
  268. payload: ModelProviderDeleteRequestDto,
  269. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  270. deleted = service.delete_provider(payload)
  271. return ok(DeleteData(deleted=deleted, providerId=payload.providerId))
  272. @router.post("/providers/test", response_model=ApiResponse[ModelProviderTestData])
  273. def test_model_provider_contract(
  274. payload: ModelProviderTestRequestDto,
  275. service: ModelServiceDep) -> ApiResponse[ModelProviderTestData]:
  276. try:
  277. result = service.test_provider(payload)
  278. except ModelProviderClientError as exc:
  279. raise HTTPException(status_code=502, detail=str(exc)) from exc
  280. if result is None:
  281. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  282. return ok(result)
  283. @router.post("/providers/discover", response_model=ApiResponse[DiscoverModelsData])
  284. def discover_models_contract(
  285. payload: DiscoverModelsRequestDto,
  286. service: ModelServiceDep) -> ApiResponse[DiscoverModelsData]:
  287. try:
  288. return ok(service.discover_models(payload))
  289. except ModelProviderClientError as exc:
  290. raise HTTPException(status_code=502, detail=str(exc)) from exc