embeddings.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
import logging
import math
from dataclasses import dataclass

import httpx

from app.application.retrieval import build_hash_embedding
from app.bootstrap.settings import KnowledgeServiceSettings
  5. class EmbeddingProviderError(Exception):
  6. pass
  7. @dataclass(frozen=True)
  8. class EmbeddingResult:
  9. embedding: list[float]
  10. model: str
  11. provider: str
  12. class EmbeddingService:
  13. def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
  14. self.settings = settings
  15. def embed_text(self, text: str) -> EmbeddingResult:
  16. if self.settings.embedding_provider == "http":
  17. try:
  18. return self._embed_with_http(text)
  19. except EmbeddingProviderError:
  20. if not self.settings.embedding_fallback_to_local:
  21. raise
  22. return self._embed_with_local_hash(text)
  23. def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
  24. return EmbeddingResult(
  25. embedding=build_hash_embedding(
  26. text,
  27. dimensions=self.settings.embedding_dimensions,
  28. ),
  29. model=self.settings.embedding_model,
  30. provider="local-hash",
  31. )
  32. def _embed_with_http(self, text: str) -> EmbeddingResult:
  33. if not self.settings.embedding_base_url:
  34. raise EmbeddingProviderError("embedding_base_url is required for http provider")
  35. headers: dict[str, str] = {}
  36. if self.settings.embedding_api_key:
  37. headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
  38. try:
  39. with httpx.Client(timeout=self.settings.embedding_timeout_seconds) as client:
  40. response = client.post(
  41. f"{self.settings.embedding_base_url.rstrip('/')}/embeddings",
  42. headers=headers,
  43. json={"model": self.settings.embedding_model, "input": text},
  44. )
  45. response.raise_for_status()
  46. payload = response.json()
  47. except (httpx.HTTPError, ValueError) as exc:
  48. raise EmbeddingProviderError(f"http embedding request failed: {exc}") from exc
  49. embedding = _read_openai_embedding(payload)
  50. if embedding is None:
  51. raise EmbeddingProviderError("embedding response missing data[0].embedding")
  52. return EmbeddingResult(
  53. embedding=embedding,
  54. model=self.settings.embedding_model,
  55. provider="http",
  56. )
  57. def _read_openai_embedding(payload: object) -> list[float] | None:
  58. if not isinstance(payload, dict):
  59. return None
  60. data = payload.get("data")
  61. if not isinstance(data, list) or not data:
  62. return None
  63. first_item = data[0]
  64. if not isinstance(first_item, dict):
  65. return None
  66. embedding = first_item.get("embedding")
  67. if not isinstance(embedding, list):
  68. return None
  69. values: list[float] = []
  70. for item in embedding:
  71. if not isinstance(item, (int, float)) or isinstance(item, bool):
  72. return None
  73. values.append(float(item))
  74. return values