services.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from core_shared import JSONValue
  2. from app.application.document_parsers import (
  3. DocumentParseError,
  4. ParsedDocument,
  5. parse_document_content,
  6. )
  7. from app.application.embeddings import EmbeddingService
  8. from app.application.retrieval import (
  9. build_chunk_payloads,
  10. cosine_similarity,
  11. keyword_score,
  12. rerank_score,
  13. stable_content_hash,
  14. )
  15. from app.bootstrap.settings import KnowledgeServiceSettings
  16. from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  17. from app.domain.repositories import (
  18. KnowledgeBaseRepository,
  19. KnowledgeChunkRepository,
  20. KnowledgeDocumentRepository,
  21. )
  22. from app.schemas.knowledge import (
  23. KnowledgeBaseCreateRequest,
  24. KnowledgeBaseStatusUpdateRequest,
  25. KnowledgeDocumentCreateRequest,
  26. KnowledgeDocumentParseRequest,
  27. KnowledgeSearchRequest,
  28. )
  29. class KnowledgeApplicationService:
  30. def __init__(
  31. self,
  32. *,
  33. settings: KnowledgeServiceSettings,
  34. base_repository: KnowledgeBaseRepository,
  35. document_repository: KnowledgeDocumentRepository,
  36. chunk_repository: KnowledgeChunkRepository,
  37. ) -> None:
  38. self.settings = settings
  39. self.base_repository = base_repository
  40. self.document_repository = document_repository
  41. self.chunk_repository = chunk_repository
  42. self.embedding_service = EmbeddingService(settings=settings)
  43. def create_base(self, payload: KnowledgeBaseCreateRequest) -> KnowledgeBase:
  44. return self.base_repository.create(
  45. tenant_id=payload.tenant_id,
  46. code=payload.code,
  47. name=payload.name,
  48. description=payload.description,
  49. metadata_json=payload.metadata_json,
  50. )
  51. def list_bases(self, *, tenant_id: str) -> list[KnowledgeBase]:
  52. return self.base_repository.list_by_tenant(tenant_id=tenant_id)
  53. def update_base_status(
  54. self,
  55. *,
  56. knowledge_base_id: str,
  57. payload: KnowledgeBaseStatusUpdateRequest,
  58. ) -> KnowledgeBase | None:
  59. return self.base_repository.update_status(
  60. tenant_id=payload.tenant_id,
  61. knowledge_base_id=knowledge_base_id,
  62. status=payload.status,
  63. )
  64. def create_document(
  65. self,
  66. payload: KnowledgeDocumentCreateRequest,
  67. ) -> tuple[KnowledgeDocument, list[KnowledgeChunk]]:
  68. knowledge_base = self.base_repository.get_by_id(
  69. tenant_id=payload.tenant_id,
  70. knowledge_base_id=payload.knowledge_base_id,
  71. )
  72. if knowledge_base is None:
  73. raise ValueError(f"knowledge base not found: {payload.knowledge_base_id}")
  74. parsed = self.parse_document(
  75. KnowledgeDocumentParseRequest(
  76. source_type=payload.source_type,
  77. source_uri=payload.source_uri,
  78. content_text=payload.content_text,
  79. content_base64=payload.content_base64,
  80. )
  81. )
  82. metadata_json = {
  83. **payload.metadata_json,
  84. "parser_metadata": parsed.metadata_json,
  85. }
  86. document = self.document_repository.create(
  87. tenant_id=payload.tenant_id,
  88. knowledge_base_id=payload.knowledge_base_id,
  89. title=payload.title,
  90. source_type=parsed.source_type,
  91. source_uri=payload.source_uri,
  92. content_text=parsed.content_text,
  93. content_hash=stable_content_hash(parsed.content_text),
  94. metadata_json=metadata_json,
  95. )
  96. chunks = self._index_document(
  97. document=document,
  98. content_text=parsed.content_text,
  99. chunk_size=payload.chunk_size,
  100. chunk_overlap=payload.chunk_overlap,
  101. )
  102. indexed_document = self.document_repository.update_status(
  103. document_id=document.id,
  104. status="indexed",
  105. )
  106. return indexed_document or document, chunks
  107. def parse_document(self, payload: KnowledgeDocumentParseRequest) -> ParsedDocument:
  108. try:
  109. return parse_document_content(
  110. source_type=payload.source_type,
  111. content_text=payload.content_text,
  112. content_base64=payload.content_base64,
  113. source_uri=payload.source_uri,
  114. )
  115. except DocumentParseError:
  116. raise
  117. def list_documents(
  118. self,
  119. *,
  120. tenant_id: str,
  121. knowledge_base_id: str,
  122. ) -> list[KnowledgeDocument]:
  123. return self.document_repository.list_by_base(
  124. tenant_id=tenant_id,
  125. knowledge_base_id=knowledge_base_id,
  126. )
  127. def search(
  128. self,
  129. payload: KnowledgeSearchRequest,
  130. ) -> list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]]:
  131. document_cache: dict[str, KnowledgeDocument] = {}
  132. query_embedding_result = self.embedding_service.embed_text(payload.query)
  133. candidate_limit = max(
  134. payload.top_k * max(self.settings.retrieval_candidate_multiplier, 1),
  135. payload.top_k,
  136. )
  137. vector_candidates = self.chunk_repository.search_by_vector(
  138. tenant_id=payload.tenant_id,
  139. knowledge_base_id=payload.knowledge_base_id,
  140. embedding=query_embedding_result.embedding,
  141. limit=candidate_limit,
  142. )
  143. if vector_candidates:
  144. chunks = [chunk for chunk, _ in vector_candidates]
  145. vector_scores_by_chunk_id = {
  146. chunk.id: score for chunk, score in vector_candidates
  147. }
  148. retrieval_mode = "pgvector-hybrid"
  149. else:
  150. chunks = self.chunk_repository.list_by_base(
  151. tenant_id=payload.tenant_id,
  152. knowledge_base_id=payload.knowledge_base_id,
  153. )
  154. vector_scores_by_chunk_id = {}
  155. retrieval_mode = "hybrid"
  156. scored: list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]] = []
  157. for chunk in chunks:
  158. document = document_cache.get(chunk.document_id)
  159. if document is None:
  160. document = self.document_repository.get_by_id(
  161. tenant_id=payload.tenant_id,
  162. document_id=chunk.document_id,
  163. )
  164. if document is None:
  165. continue
  166. document_cache[chunk.document_id] = document
  167. if not self._matches_filters(document=document, filters_json=payload.filters_json):
  168. continue
  169. keyword = keyword_score(payload.query, chunk.content_text)
  170. vector = vector_scores_by_chunk_id.get(chunk.id)
  171. if vector is None:
  172. vector = cosine_similarity(query_embedding_result.embedding, chunk.embedding_json)
  173. rerank = (
  174. rerank_score(
  175. query=payload.query,
  176. chunk_text=chunk.content_text,
  177. document_title=document.title,
  178. )
  179. if self.settings.retrieval_rerank_enabled
  180. else 0.0
  181. )
  182. score = round(
  183. keyword * self.settings.retrieval_keyword_weight
  184. + vector * self.settings.retrieval_vector_weight
  185. + rerank * self.settings.retrieval_rerank_weight,
  186. 6,
  187. )
  188. scored.append(
  189. (
  190. chunk,
  191. document,
  192. score,
  193. {
  194. "final_score": score,
  195. "keyword_score": round(keyword, 6),
  196. "vector_score": round(vector, 6),
  197. "rerank_score": round(rerank, 6),
  198. "retrieval_mode": retrieval_mode,
  199. "rerank_enabled": self.settings.retrieval_rerank_enabled,
  200. "candidate_limit": candidate_limit,
  201. "weights": {
  202. "keyword": self.settings.retrieval_keyword_weight,
  203. "vector": self.settings.retrieval_vector_weight,
  204. "rerank": self.settings.retrieval_rerank_weight,
  205. },
  206. "embedding_provider": query_embedding_result.provider,
  207. "embedding_model": query_embedding_result.model,
  208. "citation": {
  209. "document_id": document.id,
  210. "document_title": document.title,
  211. "source_uri": document.source_uri,
  212. "chunk_id": chunk.id,
  213. "chunk_index": chunk.chunk_index,
  214. },
  215. },
  216. )
  217. )
  218. scored.sort(key=lambda item: item[2], reverse=True)
  219. return scored[: payload.top_k]
  220. def _index_document(
  221. self,
  222. *,
  223. document: KnowledgeDocument,
  224. content_text: str,
  225. chunk_size: int | None,
  226. chunk_overlap: int | None,
  227. ) -> list[KnowledgeChunk]:
  228. chunk_payloads = build_chunk_payloads(
  229. content_text=content_text,
  230. chunk_size=chunk_size or self.settings.default_chunk_size,
  231. chunk_overlap=chunk_overlap or self.settings.default_chunk_overlap,
  232. )
  233. for chunk_payload in chunk_payloads:
  234. content_text = self._read_chunk_content(chunk_payload)
  235. embedding_result = self.embedding_service.embed_text(content_text)
  236. chunk_payload["embedding_model"] = embedding_result.model
  237. chunk_payload["embedding_json"] = embedding_result.embedding
  238. chunk_payload["metadata_json"] = {
  239. "embedding_provider": embedding_result.provider,
  240. }
  241. return self.chunk_repository.replace_document_chunks(
  242. tenant_id=document.tenant_id,
  243. knowledge_base_id=document.knowledge_base_id,
  244. document_id=document.id,
  245. chunks=chunk_payloads,
  246. )
  247. def _read_chunk_content(self, chunk_payload: dict[str, JSONValue]) -> str:
  248. value = chunk_payload.get("content_text")
  249. return value if isinstance(value, str) else ""
  250. def _matches_filters(
  251. self,
  252. *,
  253. document: KnowledgeDocument,
  254. filters_json: dict[str, JSONValue],
  255. ) -> bool:
  256. source_type = filters_json.get("source_type")
  257. if isinstance(source_type, str) and document.source_type != source_type:
  258. return False
  259. status = filters_json.get("status")
  260. if isinstance(status, str) and document.status != status:
  261. return False
  262. return True