retrieval.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import hashlib
  2. import math
  3. import re
  4. from collections import Counter
  5. from core_shared import JSONValue
# Token matcher: one or more consecutive word characters (``\w``) or
# characters in the explicitly listed range U+4E00-U+9FFF. Anything else
# (whitespace, punctuation) acts as a token separator.
TOKEN_PATTERN = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE)
  7. def split_text(text: str, *, chunk_size: int, chunk_overlap: int) -> list[str]:
  8. normalized_text = text.strip()
  9. if not normalized_text:
  10. return []
  11. safe_overlap = min(chunk_overlap, max(chunk_size - 1, 0))
  12. chunks: list[str] = []
  13. start = 0
  14. while start < len(normalized_text):
  15. end = min(start + chunk_size, len(normalized_text))
  16. chunks.append(normalized_text[start:end])
  17. if end == len(normalized_text):
  18. break
  19. start = end - safe_overlap
  20. return chunks
  21. def tokenize(text: str) -> list[str]:
  22. return [item.lower() for item in TOKEN_PATTERN.findall(text)]
  23. def build_hash_embedding(text: str, *, dimensions: int) -> list[float]:
  24. vector = [0.0 for _ in range(dimensions)]
  25. tokens = tokenize(text)
  26. if not tokens:
  27. return vector
  28. for token in tokens:
  29. digest = hashlib.sha256(token.encode("utf-8")).digest()
  30. index = int.from_bytes(digest[:4], "big") % dimensions
  31. sign = 1.0 if digest[4] % 2 == 0 else -1.0
  32. vector[index] += sign
  33. norm = math.sqrt(sum(item * item for item in vector))
  34. if norm == 0:
  35. return vector
  36. return [round(item / norm, 6) for item in vector]
  37. def cosine_similarity(left: list[float] | None, right: list[float] | None) -> float:
  38. if not left or not right or len(left) != len(right):
  39. return 0.0
  40. left_norm = math.sqrt(sum(item * item for item in left))
  41. right_norm = math.sqrt(sum(item * item for item in right))
  42. if left_norm == 0 or right_norm == 0:
  43. return 0.0
  44. return sum(a * b for a, b in zip(left, right, strict=True)) / (left_norm * right_norm)
  45. def keyword_score(query: str, text: str) -> float:
  46. query_tokens = tokenize(query)
  47. if not query_tokens:
  48. return 0.0
  49. text_counts = Counter(tokenize(text))
  50. if not text_counts:
  51. return 0.0
  52. matched = sum(1 for token in query_tokens if token in text_counts)
  53. frequency = sum(text_counts.get(token, 0) for token in query_tokens)
  54. return matched / len(set(query_tokens)) + min(frequency / 20.0, 1.0)
  55. def rerank_score(*, query: str, chunk_text: str, document_title: str | None = None) -> float:
  56. query_tokens = tokenize(query)
  57. if not query_tokens:
  58. return 0.0
  59. chunk_tokens = tokenize(chunk_text)
  60. title_tokens = tokenize(document_title or "")
  61. if not chunk_tokens and not title_tokens:
  62. return 0.0
  63. unique_query_tokens = set(query_tokens)
  64. chunk_token_set = set(chunk_tokens)
  65. title_token_set = set(title_tokens)
  66. coverage = len(unique_query_tokens & chunk_token_set) / len(unique_query_tokens)
  67. title_bonus = min(len(unique_query_tokens & title_token_set) / len(unique_query_tokens), 1.0)
  68. phrase_bonus = 1.0 if query.lower() in chunk_text.lower() else 0.0
  69. density = sum(1 for token in chunk_tokens if token in unique_query_tokens) / max(
  70. len(chunk_tokens),
  71. 1,
  72. )
  73. return min(coverage * 0.55 + title_bonus * 0.2 + phrase_bonus * 0.15 + density * 0.1, 1.0)
  74. def stable_content_hash(text: str) -> str:
  75. return hashlib.sha256(text.encode("utf-8")).hexdigest()
  76. def build_chunk_payloads(
  77. *,
  78. content_text: str,
  79. chunk_size: int,
  80. chunk_overlap: int,
  81. ) -> list[dict[str, JSONValue]]:
  82. chunks = split_text(
  83. content_text,
  84. chunk_size=chunk_size,
  85. chunk_overlap=chunk_overlap,
  86. )
  87. payloads: list[dict[str, JSONValue]] = []
  88. for index, chunk_text in enumerate(chunks):
  89. payloads.append(
  90. {
  91. "chunk_index": index,
  92. "content_text": chunk_text,
  93. "token_count": len(tokenize(chunk_text)),
  94. "embedding_model": None,
  95. "embedding_json": None,
  96. "metadata_json": {},
  97. }
  98. )
  99. return payloads