model.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. from datetime import datetime
  2. from typing import TYPE_CHECKING, Generic, Literal, TypeVar
  3. from core_domain import ChatCompletionResponseContract
  4. from core_shared import JSONValue
  5. from pydantic import BaseModel, Field, HttpUrl
  6. if TYPE_CHECKING:
  7. from app.db.models import ModelDefinition, ModelProviderDefinition
  8. ModelStatus = Literal["active", "disabled"]
  9. T = TypeVar("T")
  10. class ApiErrorResponse(BaseModel):
  11. errorType: str
  12. message: str
  13. details: dict[str, JSONValue] = Field(default_factory=dict)
  14. class ApiResponse(BaseModel, Generic[T]):
  15. success: bool = True
  16. data: T | None = None
  17. error: ApiErrorResponse | None = None
  18. requestId: str
  19. serverTime: datetime
  20. class PageRequest(BaseModel):
  21. page: int = Field(default=1, ge=1)
  22. pageSize: int = Field(default=20, ge=1, le=200)
  23. keyword: str | None = None
  24. @property
  25. def offset(self) -> int:
  26. return (self.page - 1) * self.pageSize
  27. class PageResult(BaseModel, Generic[T]):
  28. items: list[T]
  29. total: int
  30. page: int
  31. pageSize: int
  32. hasMore: bool
  33. @classmethod
  34. def from_items(
  35. cls,
  36. *,
  37. items: list[T],
  38. total: int,
  39. page: int,
  40. page_size: int) -> "PageResult[T]":
  41. return cls(
  42. items=items,
  43. total=total,
  44. page=page,
  45. pageSize=page_size,
  46. hasMore=page * page_size < total)
  47. class ModelCreateRequest(BaseModel):
  48. code: str | None = Field(default=None, min_length=1, max_length=64)
  49. name: str = Field(min_length=1, max_length=128)
  50. provider_id: str | None = None
  51. provider_type: str = "openai_compatible"
  52. provider_base_url: HttpUrl | str
  53. provider_api_key: str | None = None
  54. model_name: str = Field(min_length=1, max_length=128)
  55. status: ModelStatus = "active"
  56. description: str | None = None
  57. capabilities_json: list[str] = Field(default_factory=lambda: ["chat"])
  58. context_window: int | None = Field(default=None, ge=1)
  59. max_output_tokens: int | None = Field(default=None, ge=1)
  60. default_temperature: float | None = Field(default=None, ge=0, le=2)
  61. timeout_seconds: float = Field(default=60.0, ge=1, le=300)
  62. metadata_json: dict[str, JSONValue] = Field(default_factory=dict)
  63. class ModelUpdateRequest(BaseModel):
  64. code: str | None = Field(default=None, min_length=1, max_length=64)
  65. name: str | None = Field(default=None, min_length=1, max_length=128)
  66. provider_id: str | None = None
  67. provider_type: str | None = None
  68. provider_base_url: HttpUrl | str | None = None
  69. provider_api_key: str | None = None
  70. model_name: str | None = Field(default=None, min_length=1, max_length=128)
  71. status: ModelStatus | None = None
  72. description: str | None = None
  73. capabilities_json: list[str] | None = None
  74. context_window: int | None = Field(default=None, ge=1)
  75. max_output_tokens: int | None = Field(default=None, ge=1)
  76. default_temperature: float | None = Field(default=None, ge=0, le=2)
  77. timeout_seconds: float | None = Field(default=None, ge=1, le=300)
  78. metadata_json: dict[str, JSONValue] | None = None
  79. class ModelStatusUpdateRequest(BaseModel):
  80. status: ModelStatus
  81. class ModelResponse(BaseModel):
  82. id: str
  83. code: str
  84. name: str
  85. provider_id: str | None = None
  86. provider_type: str
  87. provider_base_url: str
  88. has_provider_api_key: bool
  89. model_name: str
  90. status: ModelStatus
  91. description: str | None = None
  92. capabilities_json: list[str] = Field(default_factory=list)
  93. context_window: int | None = None
  94. max_output_tokens: int | None = None
  95. default_temperature: float | None = None
  96. timeout_seconds: float
  97. metadata_json: dict[str, JSONValue] | None = None
  98. created_time: datetime
  99. updated_time: datetime
  100. @classmethod
  101. def from_entity(cls, entity: "ModelDefinition") -> "ModelResponse":
  102. return cls(
  103. id=entity.id,
  104. code=entity.code,
  105. name=entity.name,
  106. provider_id=entity.provider_id,
  107. provider_type=entity.provider_type,
  108. provider_base_url=entity.provider_base_url,
  109. has_provider_api_key=bool(entity.provider_api_key),
  110. model_name=entity.model_name,
  111. status=entity.status,
  112. description=entity.description,
  113. capabilities_json=list(entity.capabilities_json or []),
  114. context_window=entity.context_window,
  115. max_output_tokens=entity.max_output_tokens,
  116. default_temperature=entity.default_temperature,
  117. timeout_seconds=entity.timeout_seconds,
  118. metadata_json=entity.metadata_json,
  119. created_time=entity.created_time,
  120. updated_time=entity.updated_time,
  121. )
  122. class ModelDto(BaseModel):
  123. id: str
  124. name: str
  125. providerId: str | None = None
  126. providerType: str
  127. providerBaseUrl: str
  128. hasProviderApiKey: bool
  129. modelName: str
  130. description: str | None = None
  131. capabilities: list[str] = Field(default_factory=list)
  132. contextWindow: int | None = None
  133. maxOutputTokens: int | None = None
  134. defaultTemperature: float | None = None
  135. timeoutSeconds: float
  136. metadata: dict[str, JSONValue] | None = None
  137. createdTime: datetime
  138. updatedTime: datetime
  139. @classmethod
  140. def from_entity(cls, entity: "ModelDefinition") -> "ModelDto":
  141. return cls(
  142. id=entity.id,
  143. name=entity.name,
  144. providerId=entity.provider_id,
  145. providerType=entity.provider_type,
  146. providerBaseUrl=entity.provider_base_url,
  147. hasProviderApiKey=bool(entity.provider_api_key),
  148. modelName=entity.model_name,
  149. description=entity.description,
  150. capabilities=list(entity.capabilities_json or []),
  151. contextWindow=entity.context_window,
  152. maxOutputTokens=entity.max_output_tokens,
  153. defaultTemperature=entity.default_temperature,
  154. timeoutSeconds=entity.timeout_seconds,
  155. metadata=entity.metadata_json,
  156. createdTime=entity.created_time,
  157. updatedTime=entity.updated_time)
  158. class ModelCreateRequestDto(BaseModel):
  159. name: str = Field(min_length=1, max_length=128)
  160. providerId: str | None = None
  161. providerType: str = "openai_compatible"
  162. providerBaseUrl: HttpUrl | str | None = None
  163. providerApiKey: str | None = None
  164. modelName: str = Field(min_length=1, max_length=128)
  165. description: str | None = None
  166. capabilities: list[str] = Field(default_factory=lambda: ["chat"])
  167. contextWindow: int | None = Field(default=None, ge=1)
  168. maxOutputTokens: int | None = Field(default=None, ge=1)
  169. defaultTemperature: float | None = Field(default=None, ge=0, le=2)
  170. timeoutSeconds: float = Field(default=60.0, ge=1, le=300)
  171. metadata: dict[str, JSONValue] = Field(default_factory=dict)
  172. class ModelUpdateRequestDto(BaseModel):
  173. modelId: str
  174. name: str | None = Field(default=None, min_length=1, max_length=128)
  175. providerId: str | None = None
  176. providerType: str | None = None
  177. providerBaseUrl: HttpUrl | str | None = None
  178. providerApiKey: str | None = None
  179. modelName: str | None = Field(default=None, min_length=1, max_length=128)
  180. description: str | None = None
  181. capabilities: list[str] | None = None
  182. contextWindow: int | None = Field(default=None, ge=1)
  183. maxOutputTokens: int | None = Field(default=None, ge=1)
  184. defaultTemperature: float | None = Field(default=None, ge=0, le=2)
  185. timeoutSeconds: float | None = Field(default=None, ge=1, le=300)
  186. metadata: dict[str, JSONValue] | None = None
  187. class ModelDeleteRequestDto(BaseModel):
  188. modelId: str
  189. class DeleteData(BaseModel):
  190. deleted: bool
  191. modelId: str | None = None
  192. providerId: str | None = None
  193. class ModelTestRequest(BaseModel):
  194. prompt: str = Field(default="Reply with a short readiness check.", min_length=1)
  195. system_prompt: str | None = "You are a concise model connectivity checker."
  196. temperature: float | None = Field(default=None, ge=0, le=2)
  197. max_tokens: int | None = Field(default=128, ge=1)
  198. class ModelTestResponse(BaseModel):
  199. model: ModelResponse
  200. response: ChatCompletionResponseContract
  201. class ModelTestRequestDto(BaseModel):
  202. modelId: str
  203. prompt: str = Field(default="Reply with a short readiness check.", min_length=1)
  204. systemPrompt: str | None = "You are a concise model connectivity checker."
  205. temperature: float | None = Field(default=None, ge=0, le=2)
  206. maxTokens: int | None = Field(default=128, ge=1)
  207. class ModelTestData(BaseModel):
  208. model: ModelDto
  209. response: ChatCompletionResponseContract
  210. class ModelItemDto(BaseModel):
  211. modelId: str
  212. displayName: str
  213. modelType: str
  214. ownedBy: str | None = None
  215. contextWindow: int | None = None
  216. class ModelProviderDto(BaseModel):
  217. id: str
  218. name: str
  219. providerType: str
  220. baseUrl: str
  221. apiKeyRef: str
  222. models: list[ModelItemDto]
  223. defaultModel: str | None = None
  224. extraConfig: dict[str, JSONValue] = Field(default_factory=dict)
  225. createdTime: datetime
  226. updatedTime: datetime
  227. @classmethod
  228. def from_entity(cls, entity: "ModelProviderDefinition") -> "ModelProviderDto":
  229. return cls(
  230. id=entity.id,
  231. name=entity.name,
  232. providerType=entity.provider_type,
  233. baseUrl=entity.base_url,
  234. apiKeyRef=_mask_api_key(entity.api_key),
  235. models=[
  236. ModelItemDto(**_to_camel_model_item(item))
  237. for item in entity.models_json or []
  238. ],
  239. defaultModel=entity.default_model,
  240. extraConfig=entity.extra_config_json or {},
  241. createdTime=entity.created_time,
  242. updatedTime=entity.updated_time)
  243. class ModelProviderCreateRequestDto(BaseModel):
  244. name: str = Field(min_length=1, max_length=128)
  245. providerType: str = "openai_compatible"
  246. baseUrl: HttpUrl | str
  247. apiKey: str | None = None
  248. models: list[ModelItemDto] = Field(default_factory=list)
  249. defaultModel: str | None = None
  250. extraConfig: dict[str, JSONValue] = Field(default_factory=dict)
  251. class ModelProviderUpdateRequestDto(BaseModel):
  252. providerId: str
  253. name: str | None = Field(default=None, min_length=1, max_length=128)
  254. baseUrl: HttpUrl | str | None = None
  255. apiKey: str | None = None
  256. models: list[ModelItemDto] | None = None
  257. defaultModel: str | None = None
  258. extraConfig: dict[str, JSONValue] | None = None
  259. class ModelProviderDeleteRequestDto(BaseModel):
  260. providerId: str
  261. class ModelProviderTestRequestDto(BaseModel):
  262. providerId: str
  263. class ModelProviderTestData(BaseModel):
  264. success: bool
  265. message: str
  266. latencyMs: int | None = None
  267. modelList: list[str] = Field(default_factory=list)
  268. class DiscoverModelsRequestDto(BaseModel):
  269. providerId: str | None = None
  270. providerType: str | None = None
  271. baseUrl: HttpUrl | str | None = None
  272. apiKey: str | None = None
  273. class DiscoverModelsData(BaseModel):
  274. providerType: str
  275. models: list[ModelItemDto]
  276. def _mask_api_key(api_key: str | None) -> str:
  277. if not api_key:
  278. return ""
  279. return f"{api_key[:3]}***masked"
  280. def _to_snake_model_item(item: ModelItemDto) -> dict[str, JSONValue]:
  281. data: dict[str, JSONValue] = {
  282. "model_id": item.modelId,
  283. "display_name": item.displayName,
  284. "model_type": item.modelType,
  285. }
  286. if item.ownedBy is not None:
  287. data["owned_by"] = item.ownedBy
  288. if item.contextWindow is not None:
  289. data["context_window"] = item.contextWindow
  290. return data
  291. def _to_camel_model_item(item: dict[str, JSONValue]) -> dict[str, JSONValue]:
  292. return {
  293. "modelId": str(item.get("model_id") or item.get("modelId") or ""),
  294. "displayName": str(item.get("display_name") or item.get("displayName") or ""),
  295. "modelType": str(item.get("model_type") or item.get("modelType") or "chat"),
  296. "ownedBy": (
  297. item.get("owned_by")
  298. if isinstance(item.get("owned_by"), str)
  299. else item.get("ownedBy")
  300. ),
  301. "contextWindow": (
  302. item.get("context_window")
  303. if isinstance(item.get("context_window"), int)
  304. else item.get("contextWindow")
  305. ),
  306. }