retrieval.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import hashlib
  2. import math
  3. import re
  4. from collections import Counter
  5. from core_shared import JSONValue
  6. TOKEN_PATTERN = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE)
  7. _K1 = 1.5
  8. _B = 0.75
  9. def split_text(text: str, *, chunk_size: int, chunk_overlap: int) -> list[str]:
  10. normalized_text = text.strip()
  11. if not normalized_text:
  12. return []
  13. safe_overlap = min(chunk_overlap, max(chunk_size - 1, 0))
  14. chunks: list[str] = []
  15. start = 0
  16. while start < len(normalized_text):
  17. end = min(start + chunk_size, len(normalized_text))
  18. chunks.append(normalized_text[start:end])
  19. if end == len(normalized_text):
  20. break
  21. start = end - safe_overlap
  22. return chunks
  23. def tokenize(text: str) -> list[str]:
  24. return [item.lower() for item in TOKEN_PATTERN.findall(text)]
  25. def build_hash_embedding(text: str, *, dimensions: int) -> list[float]:
  26. vector = [0.0 for _ in range(dimensions)]
  27. tokens = tokenize(text)
  28. if not tokens:
  29. return vector
  30. for token in tokens:
  31. digest = hashlib.sha256(token.encode("utf-8")).digest()
  32. index = int.from_bytes(digest[:4], "big") % dimensions
  33. sign = 1.0 if digest[4] % 2 == 0 else -1.0
  34. vector[index] += sign
  35. norm = math.sqrt(sum(item * item for item in vector))
  36. if norm == 0:
  37. return vector
  38. return [round(item / norm, 6) for item in vector]
  39. def cosine_similarity(left: list[float] | None, right: list[float] | None) -> float:
  40. if not left or not right or len(left) != len(right):
  41. return 0.0
  42. left_norm = math.sqrt(sum(item * item for item in left))
  43. right_norm = math.sqrt(sum(item * item for item in right))
  44. if left_norm == 0 or right_norm == 0:
  45. return 0.0
  46. return sum(a * b for a, b in zip(left, right, strict=True)) / (left_norm * right_norm)
  47. def keyword_score(query: str, text: str) -> float:
  48. """Backward-compatible wrapper around bm25_score with fallback stats."""
  49. query_tokens = tokenize(query)
  50. if not query_tokens:
  51. return 0.0
  52. text_tokens = tokenize(text)
  53. if not text_tokens:
  54. return 0.0
  55. doc_length = len(text_tokens)
  56. avg_doc_length = float(doc_length) or 1.0
  57. doc_count = 1
  58. text_counts = Counter(text_tokens)
  59. df: dict[str, int] = {token: 1 for token in text_counts}
  60. return bm25_score(query, text, avg_doc_length=avg_doc_length, doc_count=doc_count, df=df)
  61. def bm25_score(
  62. query: str,
  63. text: str,
  64. *,
  65. avg_doc_length: float,
  66. doc_count: int,
  67. df: dict[str, int],
  68. ) -> float:
  69. """Standard BM25 scoring. k1=1.5, b=0.75."""
  70. query_tokens = tokenize(query)
  71. if not query_tokens:
  72. return 0.0
  73. text_counts = Counter(tokenize(text))
  74. doc_length = sum(text_counts.values())
  75. if not text_counts:
  76. return 0.0
  77. score = 0.0
  78. for token in set(query_tokens):
  79. tf = text_counts.get(token, 0)
  80. if tf == 0:
  81. continue
  82. idf_numerator = doc_count - df.get(token, 0) + 0.5
  83. idf_denominator = df.get(token, 0) + 0.5
  84. if idf_denominator <= 0:
  85. continue
  86. idf = math.log((idf_numerator / idf_denominator) + 1.0)
  87. tf_norm = (tf * (_K1 + 1)) / (tf + _K1 * (1 - _B + _B * doc_length / max(avg_doc_length, 1.0)))
  88. score += idf * tf_norm
  89. return max(score, 0.0)
  90. def compute_bm25_stats(chunk_texts: list[str]) -> tuple[float, int, dict[str, int]]:
  91. """Compute average doc length, total doc count, and document frequency map for BM25."""
  92. if not chunk_texts:
  93. return 0.0, 0, {}
  94. total_length = 0
  95. df: dict[str, int] = {}
  96. for text in chunk_texts:
  97. tokens = set(tokenize(text))
  98. total_length += len(tokens)
  99. for token in tokens:
  100. df[token] = df.get(token, 0) + 1
  101. avg_doc_length = total_length / len(chunk_texts)
  102. return avg_doc_length, len(chunk_texts), df
  103. def rerank_score(*, query: str, chunk_text: str, document_title: str | None = None) -> float:
  104. query_tokens = tokenize(query)
  105. if not query_tokens:
  106. return 0.0
  107. chunk_tokens = tokenize(chunk_text)
  108. title_tokens = tokenize(document_title or "")
  109. if not chunk_tokens and not title_tokens:
  110. return 0.0
  111. unique_query_tokens = set(query_tokens)
  112. chunk_token_set = set(chunk_tokens)
  113. title_token_set = set(title_tokens)
  114. coverage = len(unique_query_tokens & chunk_token_set) / len(unique_query_tokens)
  115. title_bonus = min(len(unique_query_tokens & title_token_set) / len(unique_query_tokens), 1.0)
  116. phrase_bonus = 1.0 if query.lower() in chunk_text.lower() else 0.0
  117. density = sum(1 for token in chunk_tokens if token in unique_query_tokens) / max(
  118. len(chunk_tokens),
  119. 1)
  120. return min(coverage * 0.55 + title_bonus * 0.2 + phrase_bonus * 0.15 + density * 0.1, 1.0)
  121. def stable_content_hash(text: str) -> str:
  122. return hashlib.sha256(text.encode("utf-8")).hexdigest()
  123. def build_chunk_payloads(
  124. *,
  125. content_text: str,
  126. chunk_size: int,
  127. chunk_overlap: int) -> list[dict[str, JSONValue]]:
  128. chunks = split_text(
  129. content_text,
  130. chunk_size=chunk_size,
  131. chunk_overlap=chunk_overlap)
  132. payloads: list[dict[str, JSONValue]] = []
  133. for index, chunk_text in enumerate(chunks):
  134. payloads.append(
  135. {
  136. "chunk_index": index,
  137. "content_text": chunk_text,
  138. "token_count": len(tokenize(chunk_text)),
  139. "embedding_model": None,
  140. "embedding_json": None,
  141. "metadata_json": {},
  142. }
  143. )
  144. return payloads