chunking.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. """Structure-aware document chunking."""
  2. from __future__ import annotations
  3. import re
  4. from dataclasses import dataclass
  5. from core_shared import JSONValue
  6. from app.application.retrieval import tokenize
  7. _HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
  8. _CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```")
  9. _SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?。!?])\s+")
  10. _PARAGRAPH_SPLIT_RE = re.compile(r"\n{2,}")
  11. @dataclass(frozen=True)
  12. class ChunkPayload:
  13. chunk_index: int
  14. content_text: str
  15. token_count: int
  16. metadata_json: dict[str, JSONValue]
  17. def chunk_document(
  18. *,
  19. content_text: str,
  20. source_type: str,
  21. chunk_size: int,
  22. chunk_overlap: int,
  23. ) -> list[dict[str, JSONValue]]:
  24. """Dispatch to the appropriate chunker based on source_type."""
  25. normalized = source_type.strip().lower()
  26. text_for_chunking = raw_content or content_text
  27. if normalized in {"markdown", "md"}:
  28. chunks = _chunk_markdown(text_for_chunking, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  29. elif normalized == "json":
  30. chunks = _chunk_json(content_text, chunk_size=chunk_size)
  31. else:
  32. chunks = _chunk_plain_text(content_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  33. return [
  34. {
  35. "chunk_index": c.chunk_index,
  36. "content_text": c.content_text,
  37. "token_count": c.token_count,
  38. "metadata_json": c.metadata_json,
  39. }
  40. for c in chunks
  41. ]
  42. def _chunk_markdown(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
  43. sections = _split_markdown_by_headings(content)
  44. chunks: list[ChunkPayload] = []
  45. index = 0
  46. for heading_path, section_text in sections:
  47. section_text = section_text.strip()
  48. if not section_text:
  49. continue
  50. if len(section_text) <= chunk_size:
  51. chunks.append(_make_chunk(index, section_text, {"heading_path": heading_path, "chunk_type": "heading_section"}))
  52. index += 1
  53. continue
  54. sub_parts = _split_markdown_section(section_text)
  55. buffer = ""
  56. for part_text, part_type in sub_parts:
  57. if len(buffer) + len(part_text) + 1 > chunk_size and buffer:
  58. chunks.append(_make_chunk(index, buffer.strip(), {"heading_path": heading_path, "chunk_type": part_type}))
  59. index += 1
  60. overlap_text = buffer[-chunk_overlap:] if chunk_overlap > 0 else ""
  61. buffer = overlap_text + "\n" + part_text
  62. else:
  63. buffer = buffer + "\n" + part_text if buffer else part_text
  64. if buffer.strip():
  65. chunks.append(_make_chunk(index, buffer.strip(), {"heading_path": heading_path, "chunk_type": "paragraph"}))
  66. index += 1
  67. return chunks
  68. def _split_markdown_by_headings(content: str) -> list[tuple[list[str], str]]:
  69. """Split markdown into (heading_path, section_text) tuples."""
  70. positions: list[tuple[int, int, str]] = []
  71. for match in _HEADING_RE.finditer(content):
  72. level = len(match.group(1))
  73. title = match.group(2).strip()
  74. positions.append((match.start(), level, title))
  75. if not positions:
  76. return [([], content)]
  77. sections: list[tuple[list[str], str]] = []
  78. active_headings: dict[int, str] = {}
  79. first_pos = positions[0][0]
  80. if first_pos > 0:
  81. preamble = content[:first_pos].strip()
  82. if preamble:
  83. sections.append(([], preamble))
  84. for i, (pos, level, title) in enumerate(positions):
  85. active_headings[level] = title
  86. for higher in list(active_headings):
  87. if higher > level:
  88. del active_headings[higher]
  89. path = [active_headings[l] for l in sorted(active_headings)]
  90. end = positions[i + 1][0] if i + 1 < len(positions) else len(content)
  91. section_text = content[pos:end]
  92. section_text = re.sub(r"^#{1,6}\s+.+$", "", section_text, count=1, flags=re.MULTILINE).strip()
  93. if section_text:
  94. sections.append((path, section_text))
  95. return sections
  96. def _split_markdown_section(text: str) -> list[tuple[str, str]]:
  97. """Split a markdown section into (text, chunk_type) parts."""
  98. parts: list[tuple[str, str]] = []
  99. last_end = 0
  100. for match in _CODE_BLOCK_RE.finditer(text):
  101. if match.start() > last_end:
  102. prose = text[last_end:match.start()].strip()
  103. if prose:
  104. for para in _PARAGRAPH_SPLIT_RE.split(prose):
  105. p = para.strip()
  106. if p:
  107. parts.append((p, "paragraph"))
  108. code = match.group()
  109. parts.append((code, "code_block"))
  110. last_end = match.end()
  111. if last_end < len(text):
  112. remaining = text[last_end:].strip()
  113. if remaining:
  114. for para in _PARAGRAPH_SPLIT_RE.split(remaining):
  115. p = para.strip()
  116. if p:
  117. parts.append((p, "paragraph"))
  118. return parts
  119. def _chunk_plain_text(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
  120. paragraphs = _PARAGRAPH_SPLIT_RE.split(content.strip())
  121. paragraphs = [p.strip() for p in paragraphs if p.strip()]
  122. if not paragraphs:
  123. return []
  124. chunks: list[ChunkPayload] = []
  125. buffer = ""
  126. index = 0
  127. for para in paragraphs:
  128. if len(para) > chunk_size:
  129. if buffer:
  130. chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"}))
  131. index += 1
  132. buffer = ""
  133. sentences = _split_sentences(para)
  134. sent_buffer = ""
  135. for sentence in sentences:
  136. if len(sent_buffer) + len(sentence) + 1 > chunk_size and sent_buffer:
  137. chunks.append(_make_chunk(index, sent_buffer.strip(), {"chunk_type": "sentence"}))
  138. index += 1
  139. overlap_text = sent_buffer[-chunk_overlap:] if chunk_overlap > 0 else ""
  140. sent_buffer = overlap_text + " " + sentence
  141. else:
  142. sent_buffer = sent_buffer + " " + sentence if sent_buffer else sentence
  143. if sent_buffer.strip():
  144. buffer = sent_buffer.strip()
  145. elif len(buffer) + len(para) + 2 > chunk_size and buffer:
  146. chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"}))
  147. index += 1
  148. overlap_text = buffer[-chunk_overlap:] if chunk_overlap > 0 else ""
  149. buffer = overlap_text + "\n\n" + para
  150. else:
  151. buffer = buffer + "\n\n" + para if buffer else para
  152. if buffer.strip():
  153. chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"}))
  154. return chunks
  155. def _chunk_json(content: str, *, chunk_size: int) -> list[ChunkPayload]:
  156. import json as json_lib
  157. try:
  158. data = json_lib.loads(content)
  159. except json_lib.JSONDecodeError:
  160. return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)
  161. if isinstance(data, dict):
  162. parts: list[tuple[str, str]] = []
  163. for key, value in data.items():
  164. line = f"{key}: {json_lib.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value}"
  165. parts.append((key, line))
  166. chunks: list[ChunkPayload] = []
  167. buffer = ""
  168. buffer_keys: list[str] = []
  169. index = 0
  170. for key, line in parts:
  171. if len(buffer) + len(line) + 1 > chunk_size and buffer:
  172. chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "json_keys", "key_path": buffer_keys}))
  173. index += 1
  174. buffer = line
  175. buffer_keys = [key]
  176. else:
  177. buffer = buffer + "\n" + line if buffer else line
  178. buffer_keys.append(key)
  179. if buffer.strip():
  180. chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "json_keys", "key_path": buffer_keys}))
  181. return chunks
  182. if isinstance(data, list):
  183. items_text = "\n".join(json_lib.dumps(item, ensure_ascii=False) for item in data)
  184. return _chunk_plain_text(items_text, chunk_size=chunk_size, chunk_overlap=0)
  185. return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)
  186. def _split_sentences(text: str) -> list[str]:
  187. parts = _SENTENCE_SPLIT_RE.split(text)
  188. return [p.strip() for p in parts if p.strip()]
  189. def _make_chunk(index: int, text: str, metadata: dict[str, JSONValue]) -> ChunkPayload:
  190. return ChunkPayload(
  191. chunk_index=index,
  192. content_text=text,
  193. token_count=len(tokenize(text)),
  194. metadata_json=metadata,
  195. )