Sfoglia il codice sorgente

feat: add rag rerank and trace propagation

Jax Docker 1 mese fa
parent
commit
78bd5941ea

+ 4 - 0
deployments/docker/.env.example

@@ -23,6 +23,10 @@ AGENT_PLATFORM_EMBEDDING_PROVIDER=local
 AGENT_PLATFORM_EMBEDDING_BASE_URL=
 AGENT_PLATFORM_EMBEDDING_API_KEY=
 AGENT_PLATFORM_EMBEDDING_MODEL=local-hash-v1
+AGENT_PLATFORM_RETRIEVAL_KEYWORD_WEIGHT=0.55
+AGENT_PLATFORM_RETRIEVAL_VECTOR_WEIGHT=0.30
+AGENT_PLATFORM_RETRIEVAL_RERANK_WEIGHT=0.15
+AGENT_PLATFORM_RETRIEVAL_RERANK_ENABLED=true
 AGENT_PLATFORM_MAX_TIMEOUT_SECONDS=30
 AGENT_PLATFORM_AUTH_REQUIRED=false
 AGENT_PLATFORM_AUTHZ_REQUIRED=false

+ 4 - 0
deployments/docker/docker-compose.yml

@@ -393,6 +393,10 @@ services:
       AGENT_PLATFORM_EMBEDDING_BASE_URL: ${AGENT_PLATFORM_EMBEDDING_BASE_URL:-}
       AGENT_PLATFORM_EMBEDDING_API_KEY: ${AGENT_PLATFORM_EMBEDDING_API_KEY:-}
       AGENT_PLATFORM_EMBEDDING_MODEL: ${AGENT_PLATFORM_EMBEDDING_MODEL:-local-hash-v1}
+      AGENT_PLATFORM_RETRIEVAL_KEYWORD_WEIGHT: ${AGENT_PLATFORM_RETRIEVAL_KEYWORD_WEIGHT:-0.55}
+      AGENT_PLATFORM_RETRIEVAL_VECTOR_WEIGHT: ${AGENT_PLATFORM_RETRIEVAL_VECTOR_WEIGHT:-0.30}
+      AGENT_PLATFORM_RETRIEVAL_RERANK_WEIGHT: ${AGENT_PLATFORM_RETRIEVAL_RERANK_WEIGHT:-0.15}
+      AGENT_PLATFORM_RETRIEVAL_RERANK_ENABLED: ${AGENT_PLATFORM_RETRIEVAL_RERANK_ENABLED:-true}
     ports:
       - "8012:8012"
     volumes:

+ 29 - 0
libs/core-shared/src/core_shared/observability.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import json
 import logging
+from uuid import uuid4
 from collections import defaultdict
 from dataclasses import dataclass
 from time import perf_counter
@@ -17,6 +18,9 @@ from starlette.types import ASGIApp
 
 _METRICS_CONTENT_TYPE = "text/plain; version=0.0.4; charset=utf-8"
 _DURATION_BUCKETS = (0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)
+TRACE_ID_HEADER = "x-trace-id"
+SPAN_ID_HEADER = "x-span-id"
+PARENT_SPAN_ID_HEADER = "x-parent-span-id"
 RouteDecorator = Callable[[Callable[..., Awaitable[Response]]], Callable[..., Awaitable[Response]]]
 
 
@@ -125,9 +129,19 @@ class ObservabilityMiddleware(BaseHTTPMiddleware):
 
         started_at_monotonic = perf_counter()
         status_code = 500
+        trace_id = _resolve_trace_id(request.headers)
+        parent_span_id = _header(request.headers, SPAN_ID_HEADER)
+        span_id = uuid4().hex[:16]
+        request.state.trace_id = trace_id
+        request.state.span_id = span_id
+        request.state.parent_span_id = parent_span_id
         try:
             response = await call_next(request)
             status_code = response.status_code
+            response.headers[TRACE_ID_HEADER] = trace_id
+            response.headers[SPAN_ID_HEADER] = span_id
+            if parent_span_id is not None:
+                response.headers[PARENT_SPAN_ID_HEADER] = parent_span_id
             return response
         finally:
             duration_seconds = perf_counter() - started_at_monotonic
@@ -158,6 +172,9 @@ class ObservabilityMiddleware(BaseHTTPMiddleware):
             "duration_ms": round(duration_seconds * 1000, 3),
             "request_id": _header(headers, "x-request-id"),
             "tenant_id": _header(headers, "x-tenant-id"),
+            "trace_id": getattr(request.state, "trace_id", None),
+            "span_id": getattr(request.state, "span_id", None),
+            "parent_span_id": getattr(request.state, "parent_span_id", None),
         }
         self._logger.info(json.dumps(payload, ensure_ascii=False, separators=(",", ":")))
 
@@ -191,6 +208,18 @@ def _header(headers: Headers, name: str) -> str | None:
     return value
 
 
+def _resolve_trace_id(headers: Headers) -> str:
+    """Return the trace id for a request, generating one when none is usable.
+
+    Resolution order:
+
+    1. The ``TRACE_ID_HEADER`` header, returned verbatim when present.
+    2. The trace-id field of a W3C ``traceparent`` header
+       (``version-traceid-parentid-flags``).
+    3. A freshly generated 32-character hex id from ``uuid4``.
+
+    A ``traceparent`` is ignored unless it has at least four dash-separated
+    fields and its trace-id field is 32 hex digits that are not all zeros.
+    Uppercase hex is tolerated and normalized to lowercase.
+    """
+    trace_id = _header(headers, TRACE_ID_HEADER)
+    if trace_id is not None:
+        return trace_id
+    traceparent = _header(headers, "traceparent")
+    if traceparent is not None:
+        parts = traceparent.strip().split("-")
+        if len(parts) >= 4:
+            candidate = parts[1].lower()
+            is_hex = len(candidate) == 32 and all(ch in "0123456789abcdef" for ch in candidate)
+            # An all-zero trace-id marks an invalid traceparent; start a new trace instead.
+            if is_hex and candidate != "0" * 32:
+                return candidate
+    return uuid4().hex
+
+
 def _format_bucket(value: float) -> str:
     return f"{value:g}"
 

+ 10 - 0
services/api-gateway/app/infrastructure/proxy.py

@@ -3,6 +3,7 @@ from typing import Literal
 
 import httpx
 from fastapi import Request, Response
+from core_shared.observability import PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER
 from core_shared.security import build_internal_service_headers
 
 from app.infrastructure.audit import mark_gateway_target
@@ -63,6 +64,12 @@ class ServiceProxy:
         request_context = get_gateway_request_context(request)
         headers[REQUEST_ID_HEADER] = request_context.request_id
         headers[TENANT_ID_HEADER] = request_context.tenant_id
+        trace_id = getattr(request.state, "trace_id", None)
+        span_id = getattr(request.state, "span_id", None)
+        if isinstance(trace_id, str):
+            headers[TRACE_ID_HEADER] = trace_id
+        if isinstance(span_id, str):
+            headers[PARENT_SPAN_ID_HEADER] = span_id
         headers.update(build_internal_service_headers(self.settings))
         body = await request.body()
 
@@ -123,6 +130,9 @@ def build_forward_headers(request: Request) -> dict[str, str]:
         TENANT_ID_HEADER,
         "x-internal-service-token",
         "x-internal-service-name",
+        TRACE_ID_HEADER,
+        SPAN_ID_HEADER,
+        PARENT_SPAN_ID_HEADER,
     }
     return {
         key: value

+ 21 - 0
services/knowledge-service/app/application/retrieval.py

@@ -66,6 +66,27 @@ def keyword_score(query: str, text: str) -> float:
     return matched / len(set(query_tokens)) + min(frequency / 20.0, 1.0)
 
 
+def rerank_score(*, query: str, chunk_text: str, document_title: str | None = None) -> float:
+    """Score how well a chunk matches ``query`` using lexical signals only.
+
+    The result is a weighted sum, capped at 1.0, of four components:
+
+    * coverage (0.55): share of distinct query tokens present in the chunk;
+    * title bonus (0.20): share of distinct query tokens present in the title;
+    * phrase bonus (0.15): 1.0 when the whole query occurs in the chunk,
+      compared case-insensitively with whitespace runs collapsed;
+    * density (0.10): share of chunk tokens that are query tokens.
+
+    Returns 0.0 when the query yields no tokens, or when neither the chunk
+    nor the title yields any.
+    """
+    query_tokens = tokenize(query)
+    if not query_tokens:
+        return 0.0
+    chunk_tokens = tokenize(chunk_text)
+    title_tokens = tokenize(document_title or "")
+    if not chunk_tokens and not title_tokens:
+        return 0.0
+    unique_query_tokens = set(query_tokens)
+    query_size = len(unique_query_tokens)
+    coverage = len(unique_query_tokens.intersection(chunk_tokens)) / query_size
+    # Already within [0, 1]: the intersection cannot exceed the query token set.
+    title_bonus = len(unique_query_tokens.intersection(title_tokens)) / query_size
+    # Collapse whitespace so padding or line breaks do not hide a phrase match.
+    normalized_query = " ".join(query.lower().split())
+    normalized_chunk = " ".join(chunk_text.lower().split())
+    phrase_found = bool(normalized_query) and normalized_query in normalized_chunk
+    phrase_bonus = 1.0 if phrase_found else 0.0
+    density = sum(1 for token in chunk_tokens if token in unique_query_tokens) / max(
+        len(chunk_tokens),
+        1,
+    )
+    return min(coverage * 0.55 + title_bonus * 0.2 + phrase_bonus * 0.15 + density * 0.1, 1.0)
+
+
 def stable_content_hash(text: str) -> str:
     return hashlib.sha256(text.encode("utf-8")).hexdigest()
 

+ 37 - 2
services/knowledge-service/app/application/services.py

@@ -10,6 +10,7 @@ from app.application.retrieval import (
     build_chunk_payloads,
     cosine_similarity,
     keyword_score,
+    rerank_score,
     stable_content_hash,
 )
 from app.bootstrap.settings import KnowledgeServiceSettings
@@ -140,11 +141,15 @@ class KnowledgeApplicationService:
     ) -> list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]]:
         document_cache: dict[str, KnowledgeDocument] = {}
         query_embedding_result = self.embedding_service.embed_text(payload.query)
+        candidate_limit = max(
+            payload.top_k * max(self.settings.retrieval_candidate_multiplier, 1),
+            payload.top_k,
+        )
         vector_candidates = self.chunk_repository.search_by_vector(
             tenant_id=payload.tenant_id,
             knowledge_base_id=payload.knowledge_base_id,
             embedding=query_embedding_result.embedding,
-            limit=max(payload.top_k * 5, payload.top_k),
+            limit=candidate_limit,
         )
         if vector_candidates:
             chunks = [chunk for chunk, _ in vector_candidates]
@@ -176,18 +181,48 @@ class KnowledgeApplicationService:
             vector = vector_scores_by_chunk_id.get(chunk.id)
             if vector is None:
                 vector = cosine_similarity(query_embedding_result.embedding, chunk.embedding_json)
-            score = round(keyword * 0.7 + vector * 0.3, 6)
+            rerank = (
+                rerank_score(
+                    query=payload.query,
+                    chunk_text=chunk.content_text,
+                    document_title=document.title,
+                )
+                if self.settings.retrieval_rerank_enabled
+                else 0.0
+            )
+            score = round(
+                keyword * self.settings.retrieval_keyword_weight
+                + vector * self.settings.retrieval_vector_weight
+                + rerank * self.settings.retrieval_rerank_weight,
+                6,
+            )
             scored.append(
                 (
                     chunk,
                     document,
                     score,
                     {
+                        "final_score": score,
                         "keyword_score": round(keyword, 6),
                         "vector_score": round(vector, 6),
+                        "rerank_score": round(rerank, 6),
                         "retrieval_mode": retrieval_mode,
+                        "rerank_enabled": self.settings.retrieval_rerank_enabled,
+                        "candidate_limit": candidate_limit,
+                        "weights": {
+                            "keyword": self.settings.retrieval_keyword_weight,
+                            "vector": self.settings.retrieval_vector_weight,
+                            "rerank": self.settings.retrieval_rerank_weight,
+                        },
                         "embedding_provider": query_embedding_result.provider,
                         "embedding_model": query_embedding_result.model,
+                        "citation": {
+                            "document_id": document.id,
+                            "document_title": document.title,
+                            "source_uri": document.source_uri,
+                            "chunk_id": chunk.id,
+                            "chunk_index": chunk.chunk_index,
+                        },
                     },
                 )
             )

+ 5 - 0
services/knowledge-service/app/bootstrap/settings.py

@@ -14,3 +14,8 @@ class KnowledgeServiceSettings(ServiceSettings):
     embedding_api_key: str | None = None
     embedding_timeout_seconds: float = 30.0
     embedding_fallback_to_local: bool = True
+    retrieval_keyword_weight: float = 0.55
+    retrieval_vector_weight: float = 0.30
+    retrieval_rerank_weight: float = 0.15
+    retrieval_rerank_enabled: bool = True
+    retrieval_candidate_multiplier: int = 5

+ 17 - 0
tests/test_knowledge_document_parsers.py

@@ -15,6 +15,7 @@ for path in [
     sys.path.insert(0, str(path))
 
 from app.application.document_parsers import parse_document_content
+from app.application.retrieval import rerank_score
 
 
 def test_parse_markdown_html_json_csv_documents() -> None:
@@ -41,3 +42,19 @@ def test_parse_markdown_html_json_csv_documents() -> None:
     assert "Hello world" in html.content_text
     assert "order.id: A1" in json_doc.content_text
     assert "row 2: id: A2; status: refunded" in csv_doc.content_text
+
+
+def test_rerank_score_prefers_title_and_phrase_matches() -> None:
+    """A chunk matching the query in body and title outranks an unrelated one."""
+    query = "refund policy"
+    on_topic = rerank_score(
+        query=query,
+        chunk_text="The refund policy allows refunds within seven days.",
+        document_title="Refund Policy",
+    )
+    off_topic = rerank_score(
+        query=query,
+        chunk_text="Shipping times are usually three days.",
+        document_title="Shipping",
+    )
+
+    assert on_topic > off_topic
+    assert on_topic > 0.5

+ 3 - 0
tests/test_knowledge_pgvector_fallback.py

@@ -75,6 +75,9 @@ def test_knowledge_search_falls_back_without_pgvector(tmp_path: Path) -> None:
 
         assert results
         assert results[0][3]["retrieval_mode"] == "hybrid"
+        assert results[0][3]["rerank_enabled"] is True
+        assert "citation" in results[0][3]
+        assert results[0][3]["weights"]["rerank"] == settings.retrieval_rerank_weight
     session_factory.kw["bind"].dispose()
 
 

+ 23 - 0
tests/test_observability.py

@@ -24,8 +24,31 @@ async def _run_observability_smoke() -> None:
 
     assert health_response.status_code == 200
     assert metrics_response.status_code == 200
+    assert health_response.headers["x-trace-id"]
+    assert health_response.headers["x-span-id"]
     assert 'agent_platform_service_info{service="test-service"} 1' in metrics_response.text
     assert "agent_platform_http_requests_total" in metrics_response.text
     assert 'method="GET"' in metrics_response.text
     assert 'path="/health"' in metrics_response.text
     assert 'status_code="200"' in metrics_response.text
+
+
+def test_observability_preserves_incoming_trace_id() -> None:
+    """Run the async trace-propagation smoke check to completion."""
+    smoke_check = _run_trace_propagation_smoke()
+    asyncio.run(smoke_check)
+
+
+async def _run_trace_propagation_smoke() -> None:
+    """Check that a caller-supplied ``x-trace-id`` is echoed on the response."""
+    service = FastAPI()
+    add_observability(service, "test-service")
+
+    @service.get("/health")
+    async def health() -> dict[str, str]:
+        return {"status": "ok"}
+
+    incoming_headers = {"x-trace-id": "trace-123"}
+    async with httpx.AsyncClient(
+        transport=httpx.ASGITransport(app=service),
+        base_url="http://testserver",
+    ) as client:
+        reply = await client.get("/health", headers=incoming_headers)
+
+    assert reply.status_code == 200
+    assert reply.headers["x-trace-id"] == "trace-123"
+    assert reply.headers["x-span-id"]