embeddings.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from dataclasses import dataclass
  2. import httpx
  3. from app.application.retrieval import build_hash_embedding
  4. from app.bootstrap.settings import KnowledgeServiceSettings
  5. class EmbeddingProviderError(Exception):
  6. pass
  7. @dataclass(frozen=True)
  8. class EmbeddingResult:
  9. embedding: list[float]
  10. model: str
  11. provider: str
  12. class EmbeddingService:
  13. def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
  14. self.settings = settings
  15. def embed_text(self, text: str) -> EmbeddingResult:
  16. if self.settings.embedding_provider == "http":
  17. try:
  18. return self._embed_with_http(text)
  19. except EmbeddingProviderError:
  20. if not self.settings.embedding_fallback_to_local:
  21. raise
  22. return self._embed_with_local_hash(text)
  23. def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
  24. return EmbeddingResult(
  25. embedding=build_hash_embedding(
  26. text,
  27. dimensions=self.settings.embedding_dimensions),
  28. model=self.settings.embedding_model,
  29. provider="local-hash")
  30. def _embed_with_http(self, text: str) -> EmbeddingResult:
  31. if not self.settings.embedding_base_url:
  32. raise EmbeddingProviderError("embedding_base_url is required for http provider")
  33. headers: dict[str, str] = {}
  34. if self.settings.embedding_api_key:
  35. headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
  36. try:
  37. with httpx.Client(timeout=self.settings.embedding_timeout_seconds) as client:
  38. response = client.post(
  39. f"{self.settings.embedding_base_url.rstrip('/')}/embeddings",
  40. headers=headers,
  41. json={"model": self.settings.embedding_model, "input": text})
  42. response.raise_for_status()
  43. payload = response.json()
  44. except (httpx.HTTPError, ValueError) as exc:
  45. raise EmbeddingProviderError(f"http embedding request failed: {exc}") from exc
  46. embedding = _read_openai_embedding(payload)
  47. if embedding is None:
  48. raise EmbeddingProviderError("embedding response missing data[0].embedding")
  49. return EmbeddingResult(
  50. embedding=embedding,
  51. model=self.settings.embedding_model,
  52. provider="http")
  53. def _read_openai_embedding(payload: object) -> list[float] | None:
  54. if not isinstance(payload, dict):
  55. return None
  56. data = payload.get("data")
  57. if not isinstance(data, list) or not data:
  58. return None
  59. first_item = data[0]
  60. if not isinstance(first_item, dict):
  61. return None
  62. embedding = first_item.get("embedding")
  63. if not isinstance(embedding, list):
  64. return None
  65. values: list[float] = []
  66. for item in embedding:
  67. if not isinstance(item, (int, float)) or isinstance(item, bool):
  68. return None
  69. values.append(float(item))
  70. return values