routes.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. import json
  2. from datetime import datetime
  3. from typing import Annotated, TypeVar
  4. from core_domain import (
  5. ChatCompletionRequestContract,
  6. ChatCompletionResponseContract,
  7. ServiceHealth,
  8. )
  9. from fastapi import APIRouter, Depends, HTTPException
  10. from fastapi.responses import StreamingResponse
  11. from sqlalchemy import text
  12. from sqlalchemy.orm import Session
  13. from app.application.services import ModelGatewayApplicationService
  14. from app.bootstrap.settings import ModelGatewayServiceSettings
  15. from app.db.session import get_db
  16. from app.domain.repositories import ModelDefinitionRepository, ModelProviderDefinitionRepository
  17. from app.infrastructure.provider import ModelProviderClient, ModelProviderClientError
  18. from app.schemas.model import (
  19. ApiResponse,
  20. DeleteData,
  21. DiscoverModelsData,
  22. DiscoverModelsRequestDto,
  23. ModelCreateRequest,
  24. ModelCreateRequestDto,
  25. ModelDeleteRequestDto,
  26. ModelDto,
  27. ModelProviderCreateRequestDto,
  28. ModelProviderDeleteRequestDto,
  29. ModelProviderDto,
  30. ModelProviderTestData,
  31. ModelProviderTestRequestDto,
  32. ModelProviderUpdateRequestDto,
  33. ModelResponse,
  34. ModelStatusUpdateRequest,
  35. ModelTestData,
  36. ModelTestRequest,
  37. ModelTestRequestDto,
  38. ModelTestResponse,
  39. ModelUpdateRequest,
  40. ModelUpdateRequestDto,
  41. PageRequest,
  42. PageResult,
  43. )
  44. router = APIRouter()
  45. DbSession = Annotated[Session, Depends(get_db)]
  46. T = TypeVar("T")
  47. def get_model_gateway_settings() -> ModelGatewayServiceSettings:
  48. return ModelGatewayServiceSettings()
  49. def get_model_gateway_application_service(
  50. db: DbSession,
  51. settings: Annotated[
  52. ModelGatewayServiceSettings,
  53. Depends(get_model_gateway_settings),
  54. ]) -> ModelGatewayApplicationService:
  55. return ModelGatewayApplicationService(
  56. model_repository=ModelDefinitionRepository(db),
  57. provider_repository=ModelProviderDefinitionRepository(db),
  58. provider_client=ModelProviderClient(settings=settings),
  59. settings=settings)
  60. ModelServiceDep = Annotated[
  61. ModelGatewayApplicationService,
  62. Depends(get_model_gateway_application_service),
  63. ]
  64. def ok(data: T) -> ApiResponse[T]:
  65. return ApiResponse[T](
  66. data=data,
  67. requestId="",
  68. serverTime=datetime.utcnow())
  69. def json_dump(payload: dict[str, object]) -> str:
  70. return json.dumps(payload, ensure_ascii=False, default=str)
  71. @router.get("/health", response_model=ServiceHealth)
  72. def health_check(
  73. db: DbSession,
  74. settings: Annotated[
  75. ModelGatewayServiceSettings,
  76. Depends(get_model_gateway_settings),
  77. ]) -> ServiceHealth:
  78. db.execute(text("SELECT 1"))
  79. provider_status = "configured" if settings.provider_base_url else "missing"
  80. return ServiceHealth(service="model-gateway-service", status="ok", database=provider_status)
  81. @router.post("", response_model=ModelResponse)
  82. def create_model(
  83. payload: ModelCreateRequest,
  84. service: ModelServiceDep,
  85. ) -> ModelResponse:
  86. try:
  87. entity = service.create_model(payload)
  88. except ValueError as exc:
  89. raise HTTPException(status_code=422, detail=str(exc)) from exc
  90. return ModelResponse.from_entity(entity)
  91. @router.get("", response_model=list[ModelResponse])
  92. def list_models(
  93. service: ModelServiceDep,
  94. ) -> list[ModelResponse]:
  95. return [ModelResponse.from_entity(item) for item in service.list_models()]
  96. @router.patch("/{model_id}", response_model=ModelResponse)
  97. def update_model(
  98. model_id: str,
  99. payload: ModelUpdateRequest,
  100. service: ModelServiceDep,
  101. ) -> ModelResponse:
  102. try:
  103. entity = service.update_model(model_id=model_id, payload=payload)
  104. except ValueError as exc:
  105. raise HTTPException(status_code=422, detail=str(exc)) from exc
  106. if entity is None:
  107. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  108. return ModelResponse.from_entity(entity)
  109. @router.patch("/{model_id}/status", response_model=ModelResponse)
  110. def update_model_status(
  111. model_id: str,
  112. payload: ModelStatusUpdateRequest,
  113. service: ModelServiceDep,
  114. ) -> ModelResponse:
  115. entity = service.update_model_status(model_id=model_id, payload=payload)
  116. if entity is None:
  117. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  118. return ModelResponse.from_entity(entity)
  119. @router.delete("/{model_id}", status_code=204)
  120. def delete_model(
  121. model_id: str,
  122. service: ModelServiceDep,
  123. ) -> None:
  124. if not service.delete_model(model_id):
  125. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  126. @router.post("/{model_id}/test", response_model=ModelTestResponse)
  127. def test_model(
  128. model_id: str,
  129. payload: ModelTestRequest,
  130. service: ModelServiceDep,
  131. ) -> ModelTestResponse:
  132. try:
  133. result = service.test_model(model_id=model_id, payload=payload)
  134. except ModelProviderClientError as exc:
  135. raise HTTPException(status_code=502, detail=str(exc)) from exc
  136. if result is None:
  137. raise HTTPException(status_code=404, detail=f"model not found: {model_id}")
  138. return result
  139. @router.post("/chat-completions", response_model=ChatCompletionResponseContract)
  140. def create_chat_completion(
  141. payload: ChatCompletionRequestContract,
  142. service: ModelServiceDep) -> ChatCompletionResponseContract:
  143. try:
  144. return service.create_chat_completion(payload)
  145. except ModelProviderClientError as exc:
  146. raise HTTPException(status_code=502, detail=str(exc)) from exc
  147. @router.post("/chat-completions/stream")
  148. def stream_chat_completion(
  149. payload: ChatCompletionRequestContract,
  150. service: ModelServiceDep) -> StreamingResponse:
  151. def events():
  152. try:
  153. for delta in service.stream_chat_completion(payload):
  154. yield f"event: delta\ndata: {json_dump({'delta': delta})}\n\n"
  155. yield "event: done\ndata: {}\n\n"
  156. except ModelProviderClientError as exc:
  157. yield f"event: error\ndata: {json_dump({'message': str(exc)})}\n\n"
  158. return StreamingResponse(
  159. events(),
  160. media_type="text/event-stream",
  161. headers={
  162. "Cache-Control": "no-cache",
  163. "X-Accel-Buffering": "no",
  164. })
  165. @router.post("/list", response_model=ApiResponse[PageResult[ModelDto]])
  166. def list_models_contract(
  167. payload: PageRequest,
  168. service: ModelServiceDep) -> ApiResponse[PageResult[ModelDto]]:
  169. keyword = (payload.keyword or "").lower().strip()
  170. items = [
  171. item
  172. for item in service.list_models()
  173. if not keyword
  174. or keyword in item.name.lower()
  175. or keyword in item.model_name.lower()
  176. or keyword in item.provider_type.lower()
  177. ]
  178. page_items = items[payload.offset:payload.offset + payload.pageSize]
  179. return ok(
  180. PageResult[ModelDto].from_items(
  181. items=[ModelDto.from_entity(item) for item in page_items],
  182. total=len(items),
  183. page=payload.page,
  184. page_size=payload.pageSize))
  185. @router.post("/create", response_model=ApiResponse[ModelDto])
  186. def create_model_contract(
  187. payload: ModelCreateRequestDto,
  188. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  189. try:
  190. entity = service.create_model_from_contract(payload)
  191. except ValueError as exc:
  192. raise HTTPException(status_code=422, detail=str(exc)) from exc
  193. return ok(ModelDto.from_entity(entity))
  194. @router.post("/update", response_model=ApiResponse[ModelDto])
  195. def update_model_contract(
  196. payload: ModelUpdateRequestDto,
  197. service: ModelServiceDep) -> ApiResponse[ModelDto]:
  198. try:
  199. entity = service.update_model_from_contract(payload)
  200. except ValueError as exc:
  201. raise HTTPException(status_code=422, detail=str(exc)) from exc
  202. if entity is None:
  203. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  204. return ok(ModelDto.from_entity(entity))
  205. @router.post("/delete", response_model=ApiResponse[DeleteData])
  206. def delete_model_contract(
  207. payload: ModelDeleteRequestDto,
  208. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  209. deleted = service.delete_model_from_contract(payload)
  210. return ok(DeleteData(deleted=deleted, modelId=payload.modelId))
  211. @router.post("/test", response_model=ApiResponse[ModelTestData])
  212. def test_model_contract(
  213. payload: ModelTestRequestDto,
  214. service: ModelServiceDep) -> ApiResponse[ModelTestData]:
  215. try:
  216. result = service.test_model_from_contract(payload)
  217. except ModelProviderClientError as exc:
  218. raise HTTPException(status_code=502, detail=str(exc)) from exc
  219. if result is None:
  220. raise HTTPException(status_code=404, detail=f"model not found: {payload.modelId}")
  221. return ok(result)
  222. @router.post("/providers/list", response_model=ApiResponse[PageResult[ModelProviderDto]])
  223. def list_model_providers_contract(
  224. payload: PageRequest,
  225. service: ModelServiceDep) -> ApiResponse[PageResult[ModelProviderDto]]:
  226. keyword = (payload.keyword or "").lower().strip()
  227. items = [
  228. item
  229. for item in service.list_providers()
  230. if not keyword
  231. or keyword in item.name.lower()
  232. or keyword in item.provider_type.lower()
  233. or keyword in item.base_url.lower()
  234. ]
  235. page_items = items[payload.offset:payload.offset + payload.pageSize]
  236. return ok(
  237. PageResult[ModelProviderDto].from_items(
  238. items=[ModelProviderDto.from_entity(item) for item in page_items],
  239. total=len(items),
  240. page=payload.page,
  241. page_size=payload.pageSize))
  242. @router.post("/providers/create", response_model=ApiResponse[ModelProviderDto])
  243. def create_model_provider_contract(
  244. payload: ModelProviderCreateRequestDto,
  245. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  246. entity = service.create_provider(payload)
  247. return ok(ModelProviderDto.from_entity(entity))
  248. @router.post("/providers/update", response_model=ApiResponse[ModelProviderDto])
  249. def update_model_provider_contract(
  250. payload: ModelProviderUpdateRequestDto,
  251. service: ModelServiceDep) -> ApiResponse[ModelProviderDto]:
  252. entity = service.update_provider(payload)
  253. if entity is None:
  254. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  255. return ok(ModelProviderDto.from_entity(entity))
  256. @router.post("/providers/delete", response_model=ApiResponse[DeleteData])
  257. def delete_model_provider_contract(
  258. payload: ModelProviderDeleteRequestDto,
  259. service: ModelServiceDep) -> ApiResponse[DeleteData]:
  260. deleted = service.delete_provider(payload)
  261. return ok(DeleteData(deleted=deleted, providerId=payload.providerId))
  262. @router.post("/providers/test", response_model=ApiResponse[ModelProviderTestData])
  263. def test_model_provider_contract(
  264. payload: ModelProviderTestRequestDto,
  265. service: ModelServiceDep) -> ApiResponse[ModelProviderTestData]:
  266. try:
  267. result = service.test_provider(payload)
  268. except ModelProviderClientError as exc:
  269. raise HTTPException(status_code=502, detail=str(exc)) from exc
  270. if result is None:
  271. raise HTTPException(status_code=404, detail=f"provider not found: {payload.providerId}")
  272. return ok(result)
  273. @router.post("/providers/discover", response_model=ApiResponse[DiscoverModelsData])
  274. def discover_models_contract(
  275. payload: DiscoverModelsRequestDto,
  276. service: ModelServiceDep) -> ApiResponse[DiscoverModelsData]:
  277. try:
  278. return ok(service.discover_models(payload))
  279. except ModelProviderClientError as exc:
  280. raise HTTPException(status_code=502, detail=str(exc)) from exc