services.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from core_shared import JSONValue
  2. from app.application.retrieval import (
  3. build_chunk_payloads,
  4. build_hash_embedding,
  5. cosine_similarity,
  6. keyword_score,
  7. stable_content_hash,
  8. )
  9. from app.bootstrap.settings import KnowledgeServiceSettings
  10. from app.db.models import KnowledgeBase, KnowledgeChunk, KnowledgeDocument
  11. from app.domain.repositories import (
  12. KnowledgeBaseRepository,
  13. KnowledgeChunkRepository,
  14. KnowledgeDocumentRepository,
  15. )
  16. from app.schemas.knowledge import (
  17. KnowledgeBaseCreateRequest,
  18. KnowledgeBaseStatusUpdateRequest,
  19. KnowledgeDocumentCreateRequest,
  20. KnowledgeSearchRequest,
  21. )
  22. class KnowledgeApplicationService:
  23. def __init__(
  24. self,
  25. *,
  26. settings: KnowledgeServiceSettings,
  27. base_repository: KnowledgeBaseRepository,
  28. document_repository: KnowledgeDocumentRepository,
  29. chunk_repository: KnowledgeChunkRepository,
  30. ) -> None:
  31. self.settings = settings
  32. self.base_repository = base_repository
  33. self.document_repository = document_repository
  34. self.chunk_repository = chunk_repository
  35. def create_base(self, payload: KnowledgeBaseCreateRequest) -> KnowledgeBase:
  36. return self.base_repository.create(
  37. tenant_id=payload.tenant_id,
  38. code=payload.code,
  39. name=payload.name,
  40. description=payload.description,
  41. metadata_json=payload.metadata_json,
  42. )
  43. def list_bases(self, *, tenant_id: str) -> list[KnowledgeBase]:
  44. return self.base_repository.list_by_tenant(tenant_id=tenant_id)
  45. def update_base_status(
  46. self,
  47. *,
  48. knowledge_base_id: str,
  49. payload: KnowledgeBaseStatusUpdateRequest,
  50. ) -> KnowledgeBase | None:
  51. return self.base_repository.update_status(
  52. tenant_id=payload.tenant_id,
  53. knowledge_base_id=knowledge_base_id,
  54. status=payload.status,
  55. )
  56. def create_document(
  57. self,
  58. payload: KnowledgeDocumentCreateRequest,
  59. ) -> tuple[KnowledgeDocument, list[KnowledgeChunk]]:
  60. knowledge_base = self.base_repository.get_by_id(
  61. tenant_id=payload.tenant_id,
  62. knowledge_base_id=payload.knowledge_base_id,
  63. )
  64. if knowledge_base is None:
  65. raise ValueError(f"knowledge base not found: {payload.knowledge_base_id}")
  66. document = self.document_repository.create(
  67. tenant_id=payload.tenant_id,
  68. knowledge_base_id=payload.knowledge_base_id,
  69. title=payload.title,
  70. source_type=payload.source_type,
  71. source_uri=payload.source_uri,
  72. content_text=payload.content_text,
  73. content_hash=stable_content_hash(payload.content_text),
  74. metadata_json=payload.metadata_json,
  75. )
  76. chunks = self._index_document(document=document, payload=payload)
  77. indexed_document = self.document_repository.update_status(
  78. document_id=document.id,
  79. status="indexed",
  80. )
  81. return indexed_document or document, chunks
  82. def list_documents(
  83. self,
  84. *,
  85. tenant_id: str,
  86. knowledge_base_id: str,
  87. ) -> list[KnowledgeDocument]:
  88. return self.document_repository.list_by_base(
  89. tenant_id=tenant_id,
  90. knowledge_base_id=knowledge_base_id,
  91. )
  92. def search(
  93. self,
  94. payload: KnowledgeSearchRequest,
  95. ) -> list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]]:
  96. chunks = self.chunk_repository.list_by_base(
  97. tenant_id=payload.tenant_id,
  98. knowledge_base_id=payload.knowledge_base_id,
  99. )
  100. document_cache: dict[str, KnowledgeDocument] = {}
  101. query_embedding = build_hash_embedding(
  102. payload.query,
  103. dimensions=self.settings.embedding_dimensions,
  104. )
  105. scored: list[tuple[KnowledgeChunk, KnowledgeDocument, float, dict[str, JSONValue]]] = []
  106. for chunk in chunks:
  107. document = document_cache.get(chunk.document_id)
  108. if document is None:
  109. document = self.document_repository.get_by_id(
  110. tenant_id=payload.tenant_id,
  111. document_id=chunk.document_id,
  112. )
  113. if document is None:
  114. continue
  115. document_cache[chunk.document_id] = document
  116. if not self._matches_filters(document=document, filters_json=payload.filters_json):
  117. continue
  118. keyword = keyword_score(payload.query, chunk.content_text)
  119. vector = cosine_similarity(query_embedding, chunk.embedding_json)
  120. score = round(keyword * 0.7 + vector * 0.3, 6)
  121. scored.append(
  122. (
  123. chunk,
  124. document,
  125. score,
  126. {
  127. "keyword_score": round(keyword, 6),
  128. "vector_score": round(vector, 6),
  129. "retrieval_mode": "hybrid-local",
  130. },
  131. )
  132. )
  133. scored.sort(key=lambda item: item[2], reverse=True)
  134. return scored[: payload.top_k]
  135. def _index_document(
  136. self,
  137. *,
  138. document: KnowledgeDocument,
  139. payload: KnowledgeDocumentCreateRequest,
  140. ) -> list[KnowledgeChunk]:
  141. chunk_payloads = build_chunk_payloads(
  142. content_text=payload.content_text,
  143. chunk_size=payload.chunk_size or self.settings.default_chunk_size,
  144. chunk_overlap=payload.chunk_overlap or self.settings.default_chunk_overlap,
  145. embedding_dimensions=self.settings.embedding_dimensions,
  146. embedding_model=self.settings.embedding_model,
  147. )
  148. return self.chunk_repository.replace_document_chunks(
  149. tenant_id=document.tenant_id,
  150. knowledge_base_id=document.knowledge_base_id,
  151. document_id=document.id,
  152. chunks=chunk_payloads,
  153. )
  154. def _matches_filters(
  155. self,
  156. *,
  157. document: KnowledgeDocument,
  158. filters_json: dict[str, JSONValue],
  159. ) -> bool:
  160. source_type = filters_json.get("source_type")
  161. if isinstance(source_type, str) and document.source_type != source_type:
  162. return False
  163. status = filters_json.get("status")
  164. if isinstance(status, str) and document.status != status:
  165. return False
  166. return True