"""Text chunking, tokenization, hash embeddings, and lexical relevance scoring."""

import hashlib
import math
import re
from collections import Counter

from core_shared import JSONValue

# Runs of word characters. In a ``str`` pattern ``\w`` already matches Unicode
# letters (including the CJK range listed explicitly), so a run of unspaced CJK
# text is matched as a single token.
TOKEN_PATTERN = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE)

# Default BM25 parameters: term-frequency saturation (k1) and length normalization (b).
_K1 = 1.5
_B = 0.75


def split_text(text: str, *, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split *text* into character windows of at most *chunk_size* characters.

    Leading/trailing whitespace is stripped first. Consecutive chunks share
    *chunk_overlap* characters; the overlap is clamped to ``[0, chunk_size - 1]``
    so that every step advances by at least one character and no text is skipped.

    Returns:
        The list of chunks; empty when *text* is empty or whitespace-only.

    Raises:
        ValueError: if *chunk_size* is not positive and there is text to split.
    """
    normalized_text = text.strip()
    if not normalized_text:
        return []
    if chunk_size <= 0:
        # A non-positive window never moves past ``start`` and would loop forever.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Negative overlap would leave gaps between chunks; an overlap >= chunk_size
    # would stop the window from advancing.
    safe_overlap = min(max(chunk_overlap, 0), chunk_size - 1)
    text_length = len(normalized_text)
    chunks: list[str] = []
    start = 0
    while start < text_length:
        end = min(start + chunk_size, text_length)
        chunks.append(normalized_text[start:end])
        if end == text_length:
            break
        start = end - safe_overlap
    return chunks


def tokenize(text: str) -> list[str]:
    """Return the lowercased tokens of *text* in order of appearance, duplicates kept."""
    return [token.lower() for token in TOKEN_PATTERN.findall(text)]


def build_hash_embedding(text: str, *, dimensions: int) -> list[float]:
    """Embed *text* as a signed hashed bag-of-tokens vector of length *dimensions*.

    Each token is hashed with SHA-256. The first 4 bytes (big-endian) select a
    bucket and the parity of the byte at index 4 selects a +1/-1 contribution.
    The result is L2-normalized and rounded to 6 decimal places. Text without
    tokens, or whose contributions cancel out exactly, yields a zero vector.

    Raises:
        ValueError: if *dimensions* is not positive and *text* has tokens.
    """
    tokens = tokenize(text)
    if not tokens:
        return [0.0] * dimensions
    if dimensions <= 0:
        raise ValueError(f"dimensions must be positive, got {dimensions}")
    vector = [0.0] * dimensions
    for token in tokens:
        digest = hashlib.sha256(token.encode("utf-8")).digest()
        index = int.from_bytes(digest[:4], "big") % dimensions
        vector[index] += 1.0 if digest[4] % 2 == 0 else -1.0
    norm = math.sqrt(sum(value * value for value in vector))
    if norm == 0:
        return vector
    return [round(value / norm, 6) for value in vector]


def cosine_similarity(left: list[float] | None, right: list[float] | None) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Returns 0.0 when either vector is ``None`` or empty, when their lengths
    differ, or when either has zero norm.
    """
    if not left or not right or len(left) != len(right):
        return 0.0
    left_norm = math.sqrt(sum(item * item for item in left))
    right_norm = math.sqrt(sum(item * item for item in right))
    if left_norm == 0 or right_norm == 0:
        return 0.0
    return sum(a * b for a, b in zip(left, right, strict=True)) / (left_norm * right_norm)


def keyword_score(query: str, text: str) -> float:
    """Backward-compatible wrapper around :func:`bm25_score` with single-document stats.

    *text* is treated as a one-document corpus. Every token it contains has
    document frequency 1 and the average length equals its own length, so the
    length-normalization term is neutral. Every matched query token therefore
    receives the same idf, ``log(4/3)``.
    """
    if not tokenize(query):
        return 0.0
    text_tokens = tokenize(text)
    if not text_tokens:
        return 0.0
    df = dict.fromkeys(text_tokens, 1)
    return bm25_score(
        query,
        text,
        avg_doc_length=float(len(text_tokens)),
        doc_count=1,
        df=df,
    )


def bm25_score(
    query: str,
    text: str,
    *,
    avg_doc_length: float,
    doc_count: int,
    df: dict[str, int],
    k1: float = _K1,
    b: float = _B,
) -> float:
    """Score *text* against *query* with Okapi BM25.

    Uses the idf variant ``log(1 + (N - n + 0.5) / (n + 0.5))``, where ``N`` is
    *doc_count* and ``n`` is ``df[token]`` (0 when absent). Each distinct query
    token is counted once, regardless of how often it repeats in the query.

    Args:
        query: Query text; tokenized with :func:`tokenize`.
        text: Document text; its length is its total token count.
        avg_doc_length: Mean document length in tokens; values below 1 are treated as 1.
        doc_count: Number of documents in the corpus.
        df: Number of documents containing each token.
        k1: Term-frequency saturation parameter (default 1.5).
        b: Length-normalization strength (default 0.75).

    Returns:
        The BM25 score floored at 0.0; 0.0 when query or text has no tokens.
    """
    query_terms = set(tokenize(query))
    if not query_terms:
        return 0.0
    text_counts = Counter(tokenize(text))
    if not text_counts:
        return 0.0
    doc_length = sum(text_counts.values())
    length_norm = 1 - b + b * doc_length / max(avg_doc_length, 1.0)
    score = 0.0
    for token in query_terms:
        tf = text_counts.get(token, 0)
        if tf == 0:
            continue
        doc_freq = df.get(token, 0)
        idf_denominator = doc_freq + 0.5
        if idf_denominator <= 0:
            # Only reachable with a negative df entry, which is not meaningful.
            continue
        idf = math.log(((doc_count - doc_freq + 0.5) / idf_denominator) + 1.0)
        tf_norm = (tf * (k1 + 1)) / (tf + k1 * length_norm)
        score += idf * tf_norm
    return max(score, 0.0)


def compute_bm25_stats(chunk_texts: list[str]) -> tuple[float, int, dict[str, int]]:
    """Compute ``(avg_doc_length, doc_count, df)`` for BM25 over *chunk_texts*.

    ``avg_doc_length`` is the mean number of tokens per chunk, counting repeated
    tokens, which matches how :func:`bm25_score` measures a document's length.
    ``df`` maps each token to the number of chunks containing it at least once.
    Returns ``(0.0, 0, {})`` for an empty list.
    """
    if not chunk_texts:
        return 0.0, 0, {}
    total_length = 0
    df: Counter[str] = Counter()
    for text in chunk_texts:
        tokens = tokenize(text)
        # Count every occurrence: summing distinct tokens here would understate
        # the average and skew BM25's length normalization.
        total_length += len(tokens)
        df.update(set(tokens))
    return total_length / len(chunk_texts), len(chunk_texts), dict(df)


def rerank_score(*, query: str, chunk_text: str, document_title: str | None = None) -> float:
    """Heuristic relevance of a chunk to *query*, in ``[0.0, 1.0]``.

    Weighted sum of:
      * coverage (0.55): fraction of distinct query tokens present in the chunk;
      * title bonus (0.2): fraction of distinct query tokens present in the title;
      * phrase bonus (0.15): 1.0 if the trimmed, lowercased query occurs verbatim
        in the lowercased chunk text;
      * density (0.1): fraction of chunk tokens that are query tokens.

    Returns 0.0 when the query has no tokens, or when neither the chunk nor the
    title has any.
    """
    query_terms = set(tokenize(query))
    if not query_terms:
        return 0.0
    chunk_tokens = tokenize(chunk_text)
    title_terms = set(tokenize(document_title or ""))
    if not chunk_tokens and not title_terms:
        return 0.0
    term_total = len(query_terms)
    coverage = len(query_terms & set(chunk_tokens)) / term_total
    title_bonus = len(query_terms & title_terms) / term_total
    # Strip so surrounding whitespace in the raw query cannot defeat the match.
    phrase_bonus = 1.0 if query.strip().lower() in chunk_text.lower() else 0.0
    density = sum(1 for token in chunk_tokens if token in query_terms) / max(len(chunk_tokens), 1)
    return min(coverage * 0.55 + title_bonus * 0.2 + phrase_bonus * 0.15 + density * 0.1, 1.0)


def stable_content_hash(text: str) -> str:
    """Return the hex SHA-256 digest of *text* encoded as UTF-8."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def build_chunk_payloads(
    *, content_text: str, chunk_size: int, chunk_overlap: int
) -> list[dict[str, JSONValue]]:
    """Split *content_text* with :func:`split_text` and build one payload per chunk.

    Each payload records the chunk's index, text and token count. The embedding
    fields are set to ``None`` and ``metadata_json`` is a fresh empty dict.
    """
    chunks = split_text(content_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return [
        {
            "chunk_index": index,
            "content_text": chunk_text,
            "token_count": len(tokenize(chunk_text)),
            "embedding_model": None,
            "embedding_json": None,
            "metadata_json": {},
        }
        for index, chunk_text in enumerate(chunks)
    ]