import logging
import math
from dataclasses import dataclass

import httpx

from app.application.retrieval import build_hash_embedding
from app.bootstrap.settings import KnowledgeServiceSettings

logger = logging.getLogger(__name__)


class EmbeddingProviderError(Exception):
    """Raised when a remote embedding provider fails or returns an unusable response."""


@dataclass(frozen=True)
class EmbeddingResult:
    """One embedding vector together with the model/provider that produced it."""

    embedding: list[float]
    model: str
    # Backend that actually produced the vector: "model_gateway", "http" or
    # "local-hash" (also "local-hash" after a fallback from a remote provider).
    provider: str


class EmbeddingService:
    """Compute text embeddings through the provider selected in settings.

    ``settings.embedding_provider`` selects the backend:

    * ``"model_gateway"`` -- POST to ``{model_gateway_service_url}/models/embeddings``.
    * ``"http"`` -- POST to ``{embedding_base_url}/embeddings`` with an optional
      bearer token.
    * any other value -- local ``build_hash_embedding``.

    When a remote provider raises :class:`EmbeddingProviderError` and
    ``settings.embedding_fallback_to_local`` is truthy, a warning is logged and
    the local hash embedding is returned instead; otherwise the error propagates.
    """

    def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
        self.settings = settings

    def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text with the configured provider.

        Raises:
            EmbeddingProviderError: if the remote provider fails and
                ``embedding_fallback_to_local`` is disabled.
        """
        provider = self.settings.embedding_provider
        if provider == "model_gateway":
            remote = self._embed_with_model_gateway
        elif provider == "http":
            remote = self._embed_with_http
        else:
            remote = None

        if remote is not None:
            try:
                return remote(text)
            except EmbeddingProviderError:
                if not self.settings.embedding_fallback_to_local:
                    raise
                logger.warning(
                    "%s embedding failed, falling back to local-hash",
                    provider,
                    exc_info=True,
                )
        return self._embed_with_local_hash(text)

    def embed_texts(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed several texts, returning results in the same order as ``texts``.

        Uses the same provider as :meth:`embed_text` so that batch-embedded and
        single-embedded texts come from the same backend. Fallback is
        all-or-nothing: if the remote batch fails (and fallback is enabled),
        every text is embedded locally, so one call never mixes providers.

        Raises:
            EmbeddingProviderError: if the remote provider fails and
                ``embedding_fallback_to_local`` is disabled.
        """
        if not texts:
            return []
        provider = self.settings.embedding_provider
        if provider == "model_gateway":
            remote_batch = self._embed_batch_with_model_gateway
        elif provider == "http":
            # Previously the http provider was silently skipped here, producing
            # local-hash vectors for batches but http vectors for single texts.
            remote_batch = self._embed_batch_with_http
        else:
            remote_batch = None

        if remote_batch is not None:
            try:
                return remote_batch(texts)
            except EmbeddingProviderError:
                if not self.settings.embedding_fallback_to_local:
                    raise
                logger.warning(
                    "%s batch embedding failed, falling back to local-hash",
                    provider,
                    exc_info=True,
                )
        return [self._embed_with_local_hash(t) for t in texts]

    def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
        """Embed ``text`` with ``build_hash_embedding`` at the configured dimensions."""
        return EmbeddingResult(
            embedding=build_hash_embedding(
                text, dimensions=self.settings.embedding_dimensions
            ),
            model=self.settings.embedding_model,
            provider="local-hash",
        )

    def _request_model_gateway(
        self, model_input: str | list[str], *, label: str
    ) -> object:
        """POST ``model_input`` to the model gateway and return the decoded JSON body.

        ``label`` ("embedding" / "batch embedding") is only used in the error
        message.

        Raises:
            EmbeddingProviderError: on transport errors, non-2xx status, or a
                body that cannot be decoded as JSON.
        """
        url = f"{self.settings.model_gateway_service_url.rstrip('/')}/models/embeddings"
        request_body = {
            "model": self.settings.embedding_model,
            "input": model_input,
            # Falsy dimensions (0 / None) are sent as JSON null.
            "dimensions": self.settings.embedding_dimensions or None,
        }
        try:
            with httpx.Client(
                timeout=self.settings.model_gateway_timeout_seconds
            ) as client:
                response = client.post(url, json=request_body)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as exc:
            # ValueError covers JSON decoding failures from response.json().
            raise EmbeddingProviderError(
                f"model_gateway {label} failed: {exc}"
            ) from exc

    def _embed_with_model_gateway(self, text: str) -> EmbeddingResult:
        """Embed one text via the model gateway.

        Raises:
            EmbeddingProviderError: on request failure or when the response has
                no usable ``data[0].embedding``.
        """
        payload = self._request_model_gateway(text, label="embedding")
        embedding = _read_openai_embedding(payload)
        if embedding is None:
            raise EmbeddingProviderError(
                "model_gateway response missing data[0].embedding"
            )
        return EmbeddingResult(
            embedding=embedding,
            model=self.settings.embedding_model,
            provider="model_gateway",
        )

    def _embed_batch_with_model_gateway(
        self, texts: list[str]
    ) -> list[EmbeddingResult]:
        """Embed all ``texts`` in a single model-gateway request.

        Raises:
            EmbeddingProviderError: on request failure or when the number of
                usable embeddings differs from ``len(texts)``.
        """
        payload = self._request_model_gateway(texts, label="batch embedding")
        items = _read_openai_embedding_batch(payload)
        if len(items) != len(texts):
            raise EmbeddingProviderError(
                f"model_gateway returned {len(items)} embeddings, expected {len(texts)}"
            )
        return [
            EmbeddingResult(
                embedding=emb,
                model=self.settings.embedding_model,
                provider="model_gateway",
            )
            for emb in items
        ]

    def _embed_batch_with_http(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed ``texts`` through the ``http`` provider, one request per text.

        Requests are sequential and reuse the single-input request shape of
        :meth:`_embed_with_http`; the first failure raises
        :class:`EmbeddingProviderError` for the whole batch.
        """
        return [self._embed_with_http(text) for text in texts]

    def _embed_with_http(self, text: str) -> EmbeddingResult:
        """Embed one text via ``{embedding_base_url}/embeddings``.

        Sends ``Authorization: Bearer <embedding_api_key>`` when a key is set.

        Raises:
            EmbeddingProviderError: if ``embedding_base_url`` is unset, the
                request fails, or the response has no usable
                ``data[0].embedding``.
        """
        if not self.settings.embedding_base_url:
            raise EmbeddingProviderError(
                "embedding_base_url is required for http provider"
            )
        headers: dict[str, str] = {}
        if self.settings.embedding_api_key:
            headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
        try:
            with httpx.Client(timeout=self.settings.embedding_timeout_seconds) as client:
                response = client.post(
                    f"{self.settings.embedding_base_url.rstrip('/')}/embeddings",
                    headers=headers,
                    json={"model": self.settings.embedding_model, "input": text},
                )
                response.raise_for_status()
                payload = response.json()
        except (httpx.HTTPError, ValueError) as exc:
            raise EmbeddingProviderError(
                f"http embedding request failed: {exc}"
            ) from exc
        embedding = _read_openai_embedding(payload)
        if embedding is None:
            raise EmbeddingProviderError(
                "embedding response missing data[0].embedding"
            )
        return EmbeddingResult(
            embedding=embedding,
            model=self.settings.embedding_model,
            provider="http",
        )


def _read_openai_embedding(payload: object) -> list[float] | None:
    """Return ``payload["data"][0]["embedding"]`` as floats, or None if absent/invalid."""
    if not isinstance(payload, dict):
        return None
    data = payload.get("data")
    if not isinstance(data, list) or not data:
        return None
    first_item = data[0]
    if not isinstance(first_item, dict):
        return None
    return _extract_embedding_list(first_item.get("embedding"))


def _batch_item_index(item: object) -> int:
    """Sort key for a ``data`` entry: its integer ``index``, else 0.

    Non-integer indices (e.g. null or strings) map to 0 so sorting can never
    raise TypeError on mixed key types.
    """
    if isinstance(item, dict):
        index = item.get("index")
        if isinstance(index, int) and not isinstance(index, bool):
            return index
    return 0


def _read_openai_embedding_batch(payload: object) -> list[list[float]]:
    """Return all usable embeddings from ``payload["data"]``, ordered by ``index``.

    Entries that are not dicts or carry an invalid embedding are skipped; the
    caller detects this through a length mismatch.
    """
    if not isinstance(payload, dict):
        return []
    data = payload.get("data")
    if not isinstance(data, list):
        return []
    results: list[list[float]] = []
    # sorted() is stable, so entries without an index keep their order.
    for item in sorted(data, key=_batch_item_index):
        if not isinstance(item, dict):
            continue
        embedding = _extract_embedding_list(item.get("embedding"))
        if embedding is not None:
            results.append(embedding)
    return results


def _extract_embedding_list(raw: object) -> list[float] | None:
    """Convert ``raw`` to a non-empty list of finite floats, or return None.

    Rejects non-lists, empty lists, non-numeric entries (including bools),
    values too large for a float, and NaN/Infinity (which Python's JSON
    decoder accepts).
    """
    if not isinstance(raw, list) or not raw:
        return None
    values: list[float] = []
    for item in raw:
        # bool is a subclass of int, so exclude it explicitly.
        if not isinstance(item, (int, float)) or isinstance(item, bool):
            return None
        try:
            value = float(item)
        except OverflowError:
            return None
        if not math.isfinite(value):
            return None
        values.append(value)
    return values