search_service.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """Knowledge search orchestration sub-service."""
  2. from __future__ import annotations
  3. from typing import TYPE_CHECKING
  4. from core_shared import JSONValue
  5. from app.application.embeddings import EmbeddingService
  6. from app.application.retrieval import (
  7. bm25_score,
  8. compute_bm25_stats,
  9. cosine_similarity,
  10. rerank_score,
  11. )
  12. from app.bootstrap.settings import KnowledgeServiceSettings
  13. from app.db.models import KnowledgeChunk, KnowledgeDocument
  14. from app.schemas.knowledge import KnowledgeSearchRequest
  15. if TYPE_CHECKING:
  16. from app.domain.repositories import (
  17. KnowledgeBaseRepository,
  18. KnowledgeChunkRepository,
  19. KnowledgeDocumentRepository,
  20. )
  21. class KnowledgeSearchService:
  22. def __init__(
  23. self,
  24. *,
  25. settings: KnowledgeServiceSettings,
  26. base_repository: KnowledgeBaseRepository,
  27. document_repository: KnowledgeDocumentRepository,
  28. chunk_repository: KnowledgeChunkRepository,
  29. embedding_service: EmbeddingService,
  30. ) -> None:
  31. self.settings = settings
  32. self.base_repository = base_repository
  33. self.document_repository = document_repository
  34. self.chunk_repository = chunk_repository
  35. self.embedding_service = embedding_service
  36. def search(
  37. self,
  38. payload: KnowledgeSearchRequest,
  39. ) -> list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]]:
  40. document_cache: dict[str, KnowledgeDocument] = {}
  41. query_embedding_result = self.embedding_service.embed_text(payload.query)
  42. candidate_limit = max(
  43. payload.top_k * max(self.settings.retrieval_candidate_multiplier, 1),
  44. payload.top_k,
  45. )
  46. vector_candidates = self.chunk_repository.search_by_vector(
  47. knowledge_base_id=payload.knowledge_base_id,
  48. embedding=query_embedding_result.embedding,
  49. limit=candidate_limit,
  50. )
  51. if vector_candidates:
  52. chunks = [chunk for chunk, _ in vector_candidates]
  53. vector_scores_by_chunk_id = {
  54. chunk.id: score for chunk, score in vector_candidates
  55. }
  56. retrieval_mode = "pgvector-hybrid"
  57. else:
  58. chunks = self.chunk_repository.list_by_base(
  59. knowledge_base_id=payload.knowledge_base_id,
  60. )
  61. vector_scores_by_chunk_id = {}
  62. retrieval_mode = "hybrid"
  63. kb = self.base_repository.get_by_id(knowledge_base_id=payload.knowledge_base_id)
  64. retrieval_config = (kb.metadata_json or {}).get("retrieval_config", {}) if kb else {}
  65. keyword_weight = float(retrieval_config.get("keyword_weight", self.settings.retrieval_keyword_weight))
  66. vector_weight = float(retrieval_config.get("vector_weight", self.settings.retrieval_vector_weight))
  67. rerank_weight = float(retrieval_config.get("rerank_weight", self.settings.retrieval_rerank_weight))
  68. chunk_texts = [chunk.content_text for chunk in chunks]
  69. avg_doc_length, doc_count, df_map = compute_bm25_stats(chunk_texts)
  70. scored: list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]] = []
  71. for chunk in chunks:
  72. document = document_cache.get(chunk.document_id)
  73. if document is None:
  74. document = self.document_repository.get_by_id(document_id=chunk.document_id)
  75. if document is None:
  76. continue
  77. document_cache[chunk.document_id] = document
  78. if not self._matches_filters(document=document, filters_json=payload.filters_json):
  79. continue
  80. keyword = bm25_score(
  81. payload.query, chunk.content_text,
  82. avg_doc_length=avg_doc_length, doc_count=doc_count, df=df_map,
  83. )
  84. vector = vector_scores_by_chunk_id.get(chunk.id)
  85. if vector is None:
  86. vector = cosine_similarity(query_embedding_result.embedding, chunk.embedding_json)
  87. rerank = (
  88. rerank_score(
  89. query=payload.query,
  90. chunk_text=chunk.content_text,
  91. document_title=document.title,
  92. )
  93. if self.settings.retrieval_rerank_enabled
  94. else 0.0
  95. )
  96. score = round(
  97. keyword * keyword_weight
  98. + vector * vector_weight
  99. + rerank * rerank_weight,
  100. 6,
  101. )
  102. scored.append((
  103. chunk,
  104. document,
  105. score,
  106. {
  107. "final_score": score,
  108. "keyword_score": round(keyword, 6),
  109. "vector_score": round(vector, 6),
  110. "rerank_score": round(rerank, 6),
  111. "retrieval_mode": retrieval_mode,
  112. "rerank_enabled": self.settings.retrieval_rerank_enabled,
  113. "candidate_limit": candidate_limit,
  114. "weights": {
  115. "keyword": keyword_weight,
  116. "vector": vector_weight,
  117. "rerank": rerank_weight,
  118. },
  119. "embedding_provider": query_embedding_result.provider,
  120. "embedding_model": query_embedding_result.model,
  121. "citation": {
  122. "document_id": document.id,
  123. "document_title": document.title,
  124. "source_uri": document.source_uri,
  125. "chunk_id": chunk.id,
  126. "chunk_index": chunk.chunk_index,
  127. },
  128. },
  129. ))
  130. scored.sort(key=lambda item: item[2], reverse=True)
  131. return scored[: payload.top_k]
  132. @staticmethod
  133. def _matches_filters(
  134. *,
  135. document: KnowledgeDocument,
  136. filters_json: dict[str, JSONValue],
  137. ) -> bool:
  138. source_type = filters_json.get("sourceType") or filters_json.get("source_type")
  139. if isinstance(source_type, str) and document.source_type != source_type:
  140. return False
  141. status = filters_json.get("status")
  142. if isinstance(status, str) and document.status != status:
  143. return False
  144. return True