embeddings.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
import logging
import math
from dataclasses import dataclass

import httpx

from app.application.retrieval import build_hash_embedding
from app.bootstrap.settings import KnowledgeServiceSettings
# Module-level logger; EmbeddingService uses it to warn when a remote provider
# fails and results fall back to the local-hash embedding.
logger = logging.getLogger(__name__)
  7. class EmbeddingProviderError(Exception):
  8. pass
  9. @dataclass(frozen=True)
  10. class EmbeddingResult:
  11. embedding: list[float]
  12. model: str
  13. provider: str
  14. class EmbeddingService:
  15. def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
  16. self.settings = settings
  17. def embed_text(self, text: str) -> EmbeddingResult:
  18. provider = self.settings.embedding_provider
  19. if provider == "model_gateway":
  20. try:
  21. return self._embed_with_model_gateway(text)
  22. except EmbeddingProviderError:
  23. if not self.settings.embedding_fallback_to_local:
  24. raise
  25. logger.warning("model_gateway embedding failed, falling back to local-hash")
  26. elif provider == "http":
  27. try:
  28. return self._embed_with_http(text)
  29. except EmbeddingProviderError:
  30. if not self.settings.embedding_fallback_to_local:
  31. raise
  32. logger.warning("http embedding failed, falling back to local-hash")
  33. return self._embed_with_local_hash(text)
  34. def embed_texts(self, texts: list[str]) -> list[EmbeddingResult]:
  35. if not texts:
  36. return []
  37. provider = self.settings.embedding_provider
  38. if provider == "model_gateway":
  39. try:
  40. return self._embed_batch_with_model_gateway(texts)
  41. except EmbeddingProviderError:
  42. if not self.settings.embedding_fallback_to_local:
  43. raise
  44. logger.warning("model_gateway batch embedding failed, falling back to local-hash")
  45. return [self._embed_with_local_hash(t) for t in texts]
  46. def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
  47. return EmbeddingResult(
  48. embedding=build_hash_embedding(
  49. text,
  50. dimensions=self.settings.embedding_dimensions),
  51. model=self.settings.embedding_model,
  52. provider="local-hash")
  53. def _embed_with_model_gateway(self, text: str) -> EmbeddingResult:
  54. url = f"{self.settings.model_gateway_service_url.rstrip('/')}/models/embeddings"
  55. try:
  56. with httpx.Client(timeout=self.settings.model_gateway_timeout_seconds) as client:
  57. response = client.post(url, json={
  58. "model": self.settings.embedding_model,
  59. "input": text,
  60. "dimensions": self.settings.embedding_dimensions or None,
  61. })
  62. response.raise_for_status()
  63. payload = response.json()
  64. except (httpx.HTTPError, ValueError) as exc:
  65. raise EmbeddingProviderError(f"model_gateway embedding failed: {exc}") from exc
  66. embedding = _read_openai_embedding(payload)
  67. if embedding is None:
  68. raise EmbeddingProviderError("model_gateway response missing data[0].embedding")
  69. return EmbeddingResult(
  70. embedding=embedding,
  71. model=self.settings.embedding_model,
  72. provider="model_gateway")
  73. def _embed_batch_with_model_gateway(self, texts: list[str]) -> list[EmbeddingResult]:
  74. url = f"{self.settings.model_gateway_service_url.rstrip('/')}/models/embeddings"
  75. try:
  76. with httpx.Client(timeout=self.settings.model_gateway_timeout_seconds) as client:
  77. response = client.post(url, json={
  78. "model": self.settings.embedding_model,
  79. "input": texts,
  80. "dimensions": self.settings.embedding_dimensions or None,
  81. })
  82. response.raise_for_status()
  83. payload = response.json()
  84. except (httpx.HTTPError, ValueError) as exc:
  85. raise EmbeddingProviderError(f"model_gateway batch embedding failed: {exc}") from exc
  86. items = _read_openai_embedding_batch(payload)
  87. if len(items) != len(texts):
  88. raise EmbeddingProviderError(
  89. f"model_gateway returned {len(items)} embeddings, expected {len(texts)}")
  90. return [
  91. EmbeddingResult(embedding=emb, model=self.settings.embedding_model, provider="model_gateway")
  92. for emb in items
  93. ]
  94. def _embed_with_http(self, text: str) -> EmbeddingResult:
  95. if not self.settings.embedding_base_url:
  96. raise EmbeddingProviderError("embedding_base_url is required for http provider")
  97. headers: dict[str, str] = {}
  98. if self.settings.embedding_api_key:
  99. headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
  100. try:
  101. with httpx.Client(timeout=self.settings.embedding_timeout_seconds) as client:
  102. response = client.post(
  103. f"{self.settings.embedding_base_url.rstrip('/')}/embeddings",
  104. headers=headers,
  105. json={"model": self.settings.embedding_model, "input": text})
  106. response.raise_for_status()
  107. payload = response.json()
  108. except (httpx.HTTPError, ValueError) as exc:
  109. raise EmbeddingProviderError(f"http embedding request failed: {exc}") from exc
  110. embedding = _read_openai_embedding(payload)
  111. if embedding is None:
  112. raise EmbeddingProviderError("embedding response missing data[0].embedding")
  113. return EmbeddingResult(
  114. embedding=embedding,
  115. model=self.settings.embedding_model,
  116. provider="http")
  117. def _read_openai_embedding(payload: object) -> list[float] | None:
  118. if not isinstance(payload, dict):
  119. return None
  120. data = payload.get("data")
  121. if not isinstance(data, list) or not data:
  122. return None
  123. first_item = data[0]
  124. if not isinstance(first_item, dict):
  125. return None
  126. return _extract_embedding_list(first_item.get("embedding"))
  127. def _read_openai_embedding_batch(payload: object) -> list[list[float]]:
  128. if not isinstance(payload, dict):
  129. return []
  130. data = payload.get("data")
  131. if not isinstance(data, list):
  132. return []
  133. results: list[list[float]] = []
  134. for item in sorted(data, key=lambda d: d.get("index", 0) if isinstance(d, dict) else 0):
  135. if not isinstance(item, dict):
  136. continue
  137. embedding = _extract_embedding_list(item.get("embedding"))
  138. if embedding is not None:
  139. results.append(embedding)
  140. return results
  141. def _extract_embedding_list(raw: object) -> list[float] | None:
  142. if not isinstance(raw, list):
  143. return None
  144. values: list[float] = []
  145. for item in raw:
  146. if not isinstance(item, (int, float)) or isinstance(item, bool):
  147. return None
  148. values.append(float(item))
  149. return values