| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import logging
- from dataclasses import dataclass
- import httpx
- from app.application.retrieval import build_hash_embedding
- from app.bootstrap.settings import KnowledgeServiceSettings
# Module-level logger; EmbeddingService reports local-hash fallbacks here.
logger = logging.getLogger(name=__name__)
class EmbeddingProviderError(Exception):
    """Raised when an embedding provider is misconfigured, unreachable, or returns an unusable response."""
@dataclass(frozen=True, eq=True)
class EmbeddingResult:
    """Immutable record of one embedding vector and where it came from."""

    # Vector components, as floats.
    embedding: list[float]
    # Model name recorded alongside the vector.
    model: str
    # Backend label set by EmbeddingService, e.g. "model_gateway", "http", "local-hash".
    provider: str
class EmbeddingService:
    """Produce text embeddings through the configured provider.

    ``settings.embedding_provider`` selects the backend:

    * ``"model_gateway"`` -- POST to
      ``{model_gateway_service_url}/models/embeddings``;
    * ``"http"`` -- POST to ``{embedding_base_url}/embeddings``, with a bearer
      token when ``embedding_api_key`` is set;
    * any other value -- the local ``build_hash_embedding``.

    Remote responses are parsed as OpenAI-style payloads
    (``{"data": [{"embedding": [...]}, ...]}``). When a remote call raises
    :class:`EmbeddingProviderError` and ``settings.embedding_fallback_to_local``
    is truthy, a warning is logged and the local hash embedding is returned
    instead; otherwise the error propagates to the caller.
    """

    def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
        self.settings = settings

    def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single *text* with the configured provider.

        Raises:
            EmbeddingProviderError: the remote provider failed and local
                fallback is disabled.
        """
        provider = self.settings.embedding_provider
        remote = {
            "model_gateway": self._embed_with_model_gateway,
            "http": self._embed_with_http,
        }.get(provider)
        if remote is not None:
            try:
                return remote(text)
            except EmbeddingProviderError as exc:
                if not self.settings.embedding_fallback_to_local:
                    raise
                # Include the cause so fallbacks can be diagnosed from logs.
                logger.warning(
                    "%s embedding failed, falling back to local-hash: %s", provider, exc
                )
        return self._embed_with_local_hash(text)

    def embed_texts(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed *texts* in a single request, returning results in input order.

        Uses the same provider selection and fallback policy as
        :meth:`embed_text`, so batch and single embeddings come from the same
        backend. An empty input returns ``[]`` without any request.

        Raises:
            EmbeddingProviderError: the remote provider failed and local
                fallback is disabled.
        """
        if not texts:
            return []
        provider = self.settings.embedding_provider
        # Previously only model_gateway was handled here; the "http" provider
        # silently produced local-hash vectors, inconsistent with embed_text.
        remote = {
            "model_gateway": self._embed_batch_with_model_gateway,
            "http": self._embed_batch_with_http,
        }.get(provider)
        if remote is not None:
            try:
                return remote(texts)
            except EmbeddingProviderError as exc:
                if not self.settings.embedding_fallback_to_local:
                    raise
                logger.warning(
                    "%s batch embedding failed, falling back to local-hash: %s",
                    provider,
                    exc,
                )
        return [self._embed_with_local_hash(text) for text in texts]

    # -- backends -------------------------------------------------------------

    def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
        """Embed *text* locally via ``build_hash_embedding``."""
        return EmbeddingResult(
            embedding=build_hash_embedding(
                text, dimensions=self.settings.embedding_dimensions
            ),
            model=self.settings.embedding_model,
            provider="local-hash",
        )

    def _embed_with_model_gateway(self, text: str) -> EmbeddingResult:
        """Embed one *text* through the model gateway."""
        payload = self._post_json(
            self._model_gateway_url(),
            self._model_gateway_body(text),
            timeout=self.settings.model_gateway_timeout_seconds,
            error_label="model_gateway embedding failed",
        )
        embedding = _read_openai_embedding(payload)
        if embedding is None:
            raise EmbeddingProviderError("model_gateway response missing data[0].embedding")
        return self._result(embedding, "model_gateway")

    def _embed_batch_with_model_gateway(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed all *texts* with one model gateway request."""
        payload = self._post_json(
            self._model_gateway_url(),
            self._model_gateway_body(texts),
            timeout=self.settings.model_gateway_timeout_seconds,
            error_label="model_gateway batch embedding failed",
        )
        return self._batch_results(payload, expected=len(texts), provider="model_gateway")

    def _embed_with_http(self, text: str) -> EmbeddingResult:
        """Embed one *text* through the generic OpenAI-compatible HTTP endpoint."""
        url, headers = self._http_endpoint()
        payload = self._post_json(
            url,
            {"model": self.settings.embedding_model, "input": text},
            timeout=self.settings.embedding_timeout_seconds,
            headers=headers,
            error_label="http embedding request failed",
        )
        embedding = _read_openai_embedding(payload)
        if embedding is None:
            raise EmbeddingProviderError("embedding response missing data[0].embedding")
        return self._result(embedding, "http")

    def _embed_batch_with_http(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed all *texts* with one request to the HTTP endpoint (list ``input``)."""
        url, headers = self._http_endpoint()
        payload = self._post_json(
            url,
            {"model": self.settings.embedding_model, "input": texts},
            timeout=self.settings.embedding_timeout_seconds,
            headers=headers,
            error_label="http batch embedding request failed",
        )
        return self._batch_results(payload, expected=len(texts), provider="http")

    # -- request / response helpers -------------------------------------------

    def _model_gateway_url(self) -> str:
        """Return the model gateway embeddings endpoint URL."""
        return f"{self.settings.model_gateway_service_url.rstrip('/')}/models/embeddings"

    def _model_gateway_body(self, texts: str | list[str]) -> dict[str, object]:
        """Build the model gateway request body for one text or a list of texts."""
        return {
            "model": self.settings.embedding_model,
            "input": texts,
            # A falsy setting (None or 0) is sent as JSON null, presumably so the
            # gateway applies the model's default size -- confirm with gateway.
            "dimensions": self.settings.embedding_dimensions or None,
        }

    def _http_endpoint(self) -> tuple[str, dict[str, str]]:
        """Return the HTTP provider's embeddings URL and request headers.

        Raises:
            EmbeddingProviderError: ``embedding_base_url`` is not configured.
        """
        if not self.settings.embedding_base_url:
            raise EmbeddingProviderError("embedding_base_url is required for http provider")
        headers: dict[str, str] = {}
        if self.settings.embedding_api_key:
            headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
        return f"{self.settings.embedding_base_url.rstrip('/')}/embeddings", headers

    def _post_json(
        self,
        url: str,
        body: dict[str, object],
        *,
        timeout: float,
        error_label: str,
        headers: dict[str, str] | None = None,
    ) -> object:
        """POST *body* as JSON to *url* and return the decoded JSON response.

        Raises:
            EmbeddingProviderError: transport error, non-2xx status, or a body
                that is not valid JSON; the message is ``"{error_label}: {exc}"``.
        """
        try:
            with httpx.Client(timeout=timeout) as client:
                response = client.post(url, headers=headers, json=body)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as exc:
            raise EmbeddingProviderError(f"{error_label}: {exc}") from exc

    def _batch_results(
        self, payload: object, *, expected: int, provider: str
    ) -> list[EmbeddingResult]:
        """Convert a batch response into results, requiring exactly *expected* vectors.

        Raises:
            EmbeddingProviderError: the number of valid embeddings differs
                from *expected*.
        """
        vectors = _read_openai_embedding_batch(payload)
        if len(vectors) != expected:
            raise EmbeddingProviderError(
                f"{provider} returned {len(vectors)} embeddings, expected {expected}"
            )
        return [self._result(vector, provider) for vector in vectors]

    def _result(self, embedding: list[float], provider: str) -> EmbeddingResult:
        """Wrap a remote *embedding* with the configured model name and *provider* label."""
        return EmbeddingResult(
            embedding=embedding,
            model=self.settings.embedding_model,
            provider=provider,
        )
- def _read_openai_embedding(payload: object) -> list[float] | None:
- if not isinstance(payload, dict):
- return None
- data = payload.get("data")
- if not isinstance(data, list) or not data:
- return None
- first_item = data[0]
- if not isinstance(first_item, dict):
- return None
- return _extract_embedding_list(first_item.get("embedding"))
- def _read_openai_embedding_batch(payload: object) -> list[list[float]]:
- if not isinstance(payload, dict):
- return []
- data = payload.get("data")
- if not isinstance(data, list):
- return []
- results: list[list[float]] = []
- for item in sorted(data, key=lambda d: d.get("index", 0) if isinstance(d, dict) else 0):
- if not isinstance(item, dict):
- continue
- embedding = _extract_embedding_list(item.get("embedding"))
- if embedding is not None:
- results.append(embedding)
- return results
- def _extract_embedding_list(raw: object) -> list[float] | None:
- if not isinstance(raw, list):
- return None
- values: list[float] = []
- for item in raw:
- if not isinstance(item, (int, float)) or isinstance(item, bool):
- return None
- values.append(float(item))
- return values
|