Browse Source

feat: add pluggable knowledge embeddings

Jax Docker 1 tháng trước cách đây
mục cha
commit
f408976237

+ 5 - 3
README.md

@@ -473,9 +473,11 @@ Invoke-RestMethod -Method Post `
 ## Knowledge Service APIs
 
 `knowledge-service` stores independent knowledge bases, documents, chunks, and
-retrieval results. The first version uses deterministic local hash embeddings plus
-keyword scoring, so it works without external API keys. Later pgvector and provider
-embeddings can sit behind the same search contract.
+retrieval results. It defaults to deterministic local hash embeddings plus keyword
+scoring, so it works without external API keys. For production, set
+`AGENT_PLATFORM_EMBEDDING_PROVIDER=http` with an OpenAI-compatible
+`/embeddings` endpoint; if the provider fails and fallback is enabled, indexing
+and search fall back to local hash embeddings.
 
 Create a knowledge base:
 

+ 4 - 0
deployments/docker/.env.example

@@ -1,6 +1,10 @@
 AGENT_PLATFORM_PROVIDER_BASE_URL=https://api.openai.com/v1
 AGENT_PLATFORM_PROVIDER_API_KEY=replace-me
 AGENT_PLATFORM_DEFAULT_MODEL=gpt-4o-mini
+AGENT_PLATFORM_EMBEDDING_PROVIDER=local
+AGENT_PLATFORM_EMBEDDING_BASE_URL=
+AGENT_PLATFORM_EMBEDDING_API_KEY=
+AGENT_PLATFORM_EMBEDDING_MODEL=local-hash-v1
 AGENT_PLATFORM_MAX_TIMEOUT_SECONDS=30
 AGENT_PLATFORM_AUTH_REQUIRED=false
 AGENT_PLATFORM_AUTHZ_REQUIRED=false

+ 4 - 0
deployments/docker/docker-compose.yml

@@ -283,6 +283,10 @@ services:
     command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8012"]
     environment:
       AGENT_PLATFORM_DATABASE_URL: sqlite:////data/knowledge_service.db
+      AGENT_PLATFORM_EMBEDDING_PROVIDER: ${AGENT_PLATFORM_EMBEDDING_PROVIDER:-local}
+      AGENT_PLATFORM_EMBEDDING_BASE_URL: ${AGENT_PLATFORM_EMBEDDING_BASE_URL:-}
+      AGENT_PLATFORM_EMBEDDING_API_KEY: ${AGENT_PLATFORM_EMBEDDING_API_KEY:-}
+      AGENT_PLATFORM_EMBEDDING_MODEL: ${AGENT_PLATFORM_EMBEDDING_MODEL:-local-hash-v1}
     ports:
       - "8012:8012"
     volumes:

+ 90 - 0
services/knowledge-service/app/application/embeddings.py

@@ -0,0 +1,90 @@
+from dataclasses import dataclass
+
+import httpx
+
+from app.application.retrieval import build_hash_embedding
+from app.bootstrap.settings import KnowledgeServiceSettings
+
+
+class EmbeddingProviderError(Exception):
+    pass
+
+
+@dataclass(frozen=True)
+class EmbeddingResult:
+    embedding: list[float]
+    model: str
+    provider: str
+
+
+class EmbeddingService:
+    def __init__(self, *, settings: KnowledgeServiceSettings) -> None:
+        self.settings = settings
+
+    def embed_text(self, text: str) -> EmbeddingResult:
+        if self.settings.embedding_provider == "http":
+            try:
+                return self._embed_with_http(text)
+            except EmbeddingProviderError:
+                if not self.settings.embedding_fallback_to_local:
+                    raise
+        return self._embed_with_local_hash(text)
+
+    def _embed_with_local_hash(self, text: str) -> EmbeddingResult:
+        return EmbeddingResult(
+            embedding=build_hash_embedding(
+                text,
+                dimensions=self.settings.embedding_dimensions,
+            ),
+            model=self.settings.embedding_model,
+            provider="local-hash",
+        )
+
+    def _embed_with_http(self, text: str) -> EmbeddingResult:
+        if not self.settings.embedding_base_url:
+            raise EmbeddingProviderError("embedding_base_url is required for http provider")
+
+        headers: dict[str, str] = {}
+        if self.settings.embedding_api_key:
+            headers["Authorization"] = f"Bearer {self.settings.embedding_api_key}"
+
+        try:
+            with httpx.Client(timeout=self.settings.embedding_timeout_seconds) as client:
+                response = client.post(
+                    f"{self.settings.embedding_base_url.rstrip('/')}/embeddings",
+                    headers=headers,
+                    json={"model": self.settings.embedding_model, "input": text},
+                )
+                response.raise_for_status()
+                payload = response.json()
+        except (httpx.HTTPError, ValueError) as exc:
+            raise EmbeddingProviderError(f"http embedding request failed: {exc}") from exc
+
+        embedding = _read_openai_embedding(payload)
+        if embedding is None:
+            raise EmbeddingProviderError("embedding response missing data[0].embedding")
+        return EmbeddingResult(
+            embedding=embedding,
+            model=self.settings.embedding_model,
+            provider="http",
+        )
+
+
+def _read_openai_embedding(payload: object) -> list[float] | None:
+    if not isinstance(payload, dict):
+        return None
+    data = payload.get("data")
+    if not isinstance(data, list) or not data:
+        return None
+    first_item = data[0]
+    if not isinstance(first_item, dict):
+        return None
+    embedding = first_item.get("embedding")
+    if not isinstance(embedding, list):
+        return None
+    values: list[float] = []
+    for item in embedding:
+        if not isinstance(item, (int, float)) or isinstance(item, bool):
+            return None
+        values.append(float(item))
+    return values

+ 2 - 7
services/knowledge-service/app/application/retrieval.py

@@ -75,8 +75,6 @@ def build_chunk_payloads(
     content_text: str,
     chunk_size: int,
     chunk_overlap: int,
-    embedding_dimensions: int,
-    embedding_model: str,
 ) -> list[dict[str, JSONValue]]:
     chunks = split_text(
         content_text,
@@ -90,11 +88,8 @@ def build_chunk_payloads(
                 "chunk_index": index,
                 "content_text": chunk_text,
                 "token_count": len(tokenize(chunk_text)),
-                "embedding_model": embedding_model,
-                "embedding_json": build_hash_embedding(
-                    chunk_text,
-                    dimensions=embedding_dimensions,
-                ),
+                "embedding_model": None,
+                "embedding_json": None,
                 "metadata_json": {},
             }
         )

+ 19 - 9
services/knowledge-service/app/application/services.py

@@ -1,8 +1,8 @@
 from core_shared import JSONValue
 
+from app.application.embeddings import EmbeddingService
 from app.application.retrieval import (
     build_chunk_payloads,
-    build_hash_embedding,
     cosine_similarity,
     keyword_score,
     stable_content_hash,
@@ -35,6 +35,7 @@ class KnowledgeApplicationService:
         self.base_repository = base_repository
         self.document_repository = document_repository
         self.chunk_repository = chunk_repository
+        self.embedding_service = EmbeddingService(settings=settings)
 
     def create_base(self, payload: KnowledgeBaseCreateRequest) -> KnowledgeBase:
         return self.base_repository.create(
@@ -108,10 +109,7 @@ class KnowledgeApplicationService:
             knowledge_base_id=payload.knowledge_base_id,
         )
         document_cache: dict[str, KnowledgeDocument] = {}
-        query_embedding = build_hash_embedding(
-            payload.query,
-            dimensions=self.settings.embedding_dimensions,
-        )
+        query_embedding_result = self.embedding_service.embed_text(payload.query)
         scored: list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]] = []
         for chunk in chunks:
             document = document_cache.get(chunk.document_id)
@@ -126,7 +124,7 @@ class KnowledgeApplicationService:
             if not self._matches_filters(document=document, filters_json=payload.filters_json):
                 continue
             keyword = keyword_score(payload.query, chunk.content_text)
-            vector = cosine_similarity(query_embedding, chunk.embedding_json)
+            vector = cosine_similarity(query_embedding_result.embedding, chunk.embedding_json)
             score = round(keyword * 0.7 + vector * 0.3, 6)
             scored.append(
                 (
@@ -136,7 +134,9 @@ class KnowledgeApplicationService:
                     {
                         "keyword_score": round(keyword, 6),
                         "vector_score": round(vector, 6),
-                        "retrieval_mode": "hybrid-local",
+                        "retrieval_mode": "hybrid",
+                        "embedding_provider": query_embedding_result.provider,
+                        "embedding_model": query_embedding_result.model,
                     },
                 )
             )
@@ -153,9 +153,15 @@ class KnowledgeApplicationService:
             content_text=payload.content_text,
             chunk_size=payload.chunk_size or self.settings.default_chunk_size,
             chunk_overlap=payload.chunk_overlap or self.settings.default_chunk_overlap,
-            embedding_dimensions=self.settings.embedding_dimensions,
-            embedding_model=self.settings.embedding_model,
         )
+        for chunk_payload in chunk_payloads:
+            content_text = self._read_chunk_content(chunk_payload)
+            embedding_result = self.embedding_service.embed_text(content_text)
+            chunk_payload["embedding_model"] = embedding_result.model
+            chunk_payload["embedding_json"] = embedding_result.embedding
+            chunk_payload["metadata_json"] = {
+                "embedding_provider": embedding_result.provider,
+            }
         return self.chunk_repository.replace_document_chunks(
             tenant_id=document.tenant_id,
             knowledge_base_id=document.knowledge_base_id,
@@ -163,6 +169,10 @@ class KnowledgeApplicationService:
             chunks=chunk_payloads,
         )
 
+    def _read_chunk_content(self, chunk_payload: dict[str, JSONValue]) -> str:
+        value = chunk_payload.get("content_text")
+        return value if isinstance(value, str) else ""
+
     def _matches_filters(
         self,
         *,

+ 5 - 0
services/knowledge-service/app/bootstrap/settings.py

@@ -9,3 +9,8 @@ class KnowledgeServiceSettings(ServiceSettings):
     default_chunk_overlap: int = 120
     embedding_dimensions: int = 32
     embedding_model: str = "local-hash-v1"
+    embedding_provider: str = "local"
+    embedding_base_url: str | None = None
+    embedding_api_key: str | None = None
+    embedding_timeout_seconds: float = 30.0
+    embedding_fallback_to_local: bool = True

+ 1 - 0
services/knowledge-service/pyproject.toml

@@ -10,6 +10,7 @@ requires-python = ">=3.11"
 dependencies = [
   "alembic>=1.13,<2.0",
   "fastapi>=0.111,<1.0",
+  "httpx>=0.27,<1.0",
   "uvicorn[standard]>=0.30,<1.0",
   "pydantic>=2.7,<3.0",
   "sqlalchemy>=2.0,<3.0",