| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- """Knowledge search orchestration sub-service."""
- from __future__ import annotations
- from typing import TYPE_CHECKING
- from core_shared import JSONValue
- from app.application.embeddings import EmbeddingService
- from app.application.retrieval import (
- bm25_score,
- compute_bm25_stats,
- cosine_similarity,
- rerank_score,
- )
- from app.bootstrap.settings import KnowledgeServiceSettings
- from app.db.models import KnowledgeChunk, KnowledgeDocument
- from app.schemas.knowledge import KnowledgeSearchRequest
- if TYPE_CHECKING:
- from app.domain.repositories import (
- KnowledgeBaseRepository,
- KnowledgeChunkRepository,
- KnowledgeDocumentRepository,
- )
class KnowledgeSearchService:
    """Rank the chunks of a knowledge base against a free-text query.

    Every candidate chunk receives a BM25 keyword score, a vector-similarity
    score and (when enabled in the settings) a rerank score.  The three are
    combined with configurable weights into one final score, and the
    best-scoring chunks are returned together with a per-chunk explanation.
    """

    def __init__(
        self,
        *,
        settings: KnowledgeServiceSettings,
        base_repository: KnowledgeBaseRepository,
        document_repository: KnowledgeDocumentRepository,
        chunk_repository: KnowledgeChunkRepository,
        embedding_service: EmbeddingService,
    ) -> None:
        """Store the collaborators used by :meth:`search`."""
        self.settings = settings
        self.base_repository = base_repository
        self.document_repository = document_repository
        self.chunk_repository = chunk_repository
        self.embedding_service = embedding_service

    def search(
        self,
        payload: KnowledgeSearchRequest,
    ) -> list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]]:
        """Return the ``payload.top_k`` best-scoring chunks for ``payload.query``.

        Candidates come from ``chunk_repository.search_by_vector``; when that
        yields nothing, every chunk returned by ``list_by_base`` is scored
        instead and the vector score is computed with ``cosine_similarity``
        against ``chunk.embedding_json``.  Chunks whose document cannot be
        loaded, or whose document fails ``payload.filters_json``, are skipped.

        Args:
            payload: Search request; ``query``, ``top_k``,
                ``knowledge_base_id`` and ``filters_json`` are read.

        Returns:
            ``(chunk, document, final_score, details)`` tuples sorted by
            ``final_score`` in descending order, at most ``top_k`` long.
            ``details`` holds the individual scores, the weights used, the
            retrieval mode, the embedding provider/model and a citation.
            An empty list is returned when ``top_k`` is not positive or when
            there are no candidate chunks.
        """
        top_k = payload.top_k
        if top_k <= 0:
            # Nothing can be returned.  Bailing out here also prevents a
            # negative value from turning the final slice into "everything
            # except the last N results".
            return []

        query_embedding_result = self.embedding_service.embed_text(payload.query)
        candidate_limit = max(
            top_k * max(self.settings.retrieval_candidate_multiplier, 1),
            top_k,
        )
        vector_candidates = self.chunk_repository.search_by_vector(
            knowledge_base_id=payload.knowledge_base_id,
            embedding=query_embedding_result.embedding,
            limit=candidate_limit,
        )
        if vector_candidates:
            chunks = [chunk for chunk, _ in vector_candidates]
            vector_scores_by_chunk_id = {
                chunk.id: score for chunk, score in vector_candidates
            }
            retrieval_mode = "pgvector-hybrid"
        else:
            # No vector hits: fall back to scoring every chunk of the base.
            chunks = self.chunk_repository.list_by_base(
                knowledge_base_id=payload.knowledge_base_id,
            )
            vector_scores_by_chunk_id = {}
            retrieval_mode = "hybrid"
        if not chunks:
            # Nothing to score, so skip the knowledge-base lookup and the
            # BM25 statistics over an empty corpus.
            return []

        keyword_weight, vector_weight, rerank_weight = self._resolve_weights(
            payload.knowledge_base_id,
        )
        rerank_enabled = self.settings.retrieval_rerank_enabled
        avg_doc_length, doc_count, df_map = compute_bm25_stats(
            [chunk.content_text for chunk in chunks],
        )

        # Memoise document lookups, including misses, so that many chunks of
        # the same (possibly deleted) document cost a single repository call.
        document_cache: dict[str, KnowledgeDocument | None] = {}
        scored: list[
            tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]
        ] = []
        for chunk in chunks:
            document_id = chunk.document_id
            if document_id not in document_cache:
                document_cache[document_id] = self.document_repository.get_by_id(
                    document_id=document_id,
                )
            document = document_cache[document_id]
            if document is None:
                continue
            if not self._matches_filters(
                document=document,
                filters_json=payload.filters_json,
            ):
                continue

            keyword = bm25_score(
                payload.query,
                chunk.content_text,
                avg_doc_length=avg_doc_length,
                doc_count=doc_count,
                df=df_map,
            )
            vector = vector_scores_by_chunk_id.get(chunk.id)
            if vector is None:
                # The chunk did not come with a score from the vector search.
                vector = cosine_similarity(
                    query_embedding_result.embedding,
                    chunk.embedding_json,
                )
            if rerank_enabled:
                rerank = rerank_score(
                    query=payload.query,
                    chunk_text=chunk.content_text,
                    document_title=document.title,
                )
            else:
                rerank = 0.0

            score = round(
                keyword * keyword_weight
                + vector * vector_weight
                + rerank * rerank_weight,
                6,
            )
            details: dict[str, JSONValue] = {
                "final_score": score,
                "keyword_score": round(keyword, 6),
                "vector_score": round(vector, 6),
                "rerank_score": round(rerank, 6),
                "retrieval_mode": retrieval_mode,
                "rerank_enabled": rerank_enabled,
                "candidate_limit": candidate_limit,
                "weights": {
                    "keyword": keyword_weight,
                    "vector": vector_weight,
                    "rerank": rerank_weight,
                },
                "embedding_provider": query_embedding_result.provider,
                "embedding_model": query_embedding_result.model,
                "citation": {
                    "document_id": document.id,
                    "document_title": document.title,
                    "source_uri": document.source_uri,
                    "chunk_id": chunk.id,
                    "chunk_index": chunk.chunk_index,
                },
            }
            scored.append((chunk, document, score, details))

        # list.sort is stable, so equally scored chunks keep candidate order.
        scored.sort(key=lambda item: item[2], reverse=True)
        return scored[:top_k]

    def _resolve_weights(self, knowledge_base_id: str) -> tuple[float, float, float]:
        """Return the ``(keyword, vector, rerank)`` weights for a knowledge base.

        Values under ``metadata_json["retrieval_config"]`` of the knowledge
        base take precedence over the service settings.  A missing knowledge
        base, missing metadata, a missing/``None`` ``retrieval_config`` or a
        missing/``None`` individual weight all fall back to the settings.
        """
        kb = self.base_repository.get_by_id(knowledge_base_id=knowledge_base_id)
        metadata = (kb.metadata_json or {}) if kb else {}
        # ``or {}`` also covers an explicit ``null`` stored for the key.
        retrieval_config = metadata.get("retrieval_config") or {}

        def pick(key: str, default: float) -> float:
            value = retrieval_config.get(key)
            return float(default if value is None else value)

        return (
            pick("keyword_weight", self.settings.retrieval_keyword_weight),
            pick("vector_weight", self.settings.retrieval_vector_weight),
            pick("rerank_weight", self.settings.retrieval_rerank_weight),
        )

    @staticmethod
    def _matches_filters(
        *,
        document: KnowledgeDocument,
        filters_json: dict[str, JSONValue] | None,
    ) -> bool:
        """Tell whether ``document`` satisfies the request filters.

        Two filters are recognised: the source type (key ``"sourceType"`` or
        ``"source_type"``) and ``"status"``.  A filter is enforced only when
        its value is a string; other values are ignored.  ``None`` or an empty
        mapping means "no filters" and matches every document.
        """
        if not filters_json:
            return True
        source_type = filters_json.get("sourceType") or filters_json.get("source_type")
        if isinstance(source_type, str) and document.source_type != source_type:
            return False
        status = filters_json.get("status")
        if isinstance(status, str) and document.status != status:
            return False
        return True
|