chunking.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """Structure-aware document chunking."""
  2. from __future__ import annotations
  3. import re
  4. from dataclasses import dataclass
  5. from core_shared import JSONValue
  6. from app.application.retrieval import tokenize
  # ATX heading: 1-6 "#" markers, blanks on the SAME line, then the title text.
  # The separator is [ \t]+ rather than \s+ so it cannot swallow a newline;
  # with \s+ a bare "#" line followed by ordinary text parsed as a heading.
  _HEADING_RE = re.compile(r"^(#{1,6})[ \t]+(.+)$", re.MULTILINE)

  # Fenced code block: the shortest span between two triple-backtick fences.
  # An unterminated fence does not match.
  _CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```")

  # Split point: whitespace that directly follows a sentence terminator.
  # Terminators are ASCII . ! ? plus the ideographic full stop (U+3002) and the
  # full-width ! and ? (U+FF01, U+FF1F). They are written as escapes so that
  # Unicode normalisation of the source file cannot fold them back into ASCII
  # duplicates.
  _SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?\u3002\uff01\uff1f])\s+")

  # Paragraph boundary: two or more consecutive newlines.
  _PARAGRAPH_SPLIT_RE = re.compile(r"\n{2,}")
  @dataclass(frozen=True)
  class ChunkPayload:
      """Immutable record for a single chunk produced by one of the chunkers."""

      # Zero-based position of the chunk within its document.
      chunk_index: int
      # The chunk's text, already stripped by the chunker that built it.
      content_text: str
      # Number of tokens that ``tokenize`` reports for ``content_text``.
      token_count: int
      # Chunker-specific details such as "chunk_type", "heading_path" or "key_path".
      metadata_json: dict[str, JSONValue]
  def chunk_document(
      *,
      content_text: str,
      source_type: str,
      chunk_size: int,
      chunk_overlap: int,
  ) -> list[dict[str, JSONValue]]:
      """Chunk a document with the chunker that matches ``source_type``.

      ``source_type`` is compared case-insensitively after trimming whitespace:
      "markdown" / "md" select the markdown chunker, "json" selects the JSON
      chunker, and anything else is treated as plain text. The JSON chunker is
      not given ``chunk_overlap``.

      Returns:
          One plain dict per chunk with the keys "chunk_index", "content_text",
          "token_count" and "metadata_json".
      """
      kind = source_type.strip().lower()

      if kind == "json":
          payloads = _chunk_json(content_text, chunk_size=chunk_size)
      elif kind in ("markdown", "md"):
          payloads = _chunk_markdown(
              content_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap
          )
      else:
          payloads = _chunk_plain_text(
              content_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap
          )

      # Flatten the frozen payload objects into JSON-friendly dicts.
      serialized: list[dict[str, JSONValue]] = []
      for payload in payloads:
          serialized.append(
              {
                  "chunk_index": payload.chunk_index,
                  "content_text": payload.content_text,
                  "token_count": payload.token_count,
                  "metadata_json": payload.metadata_json,
              }
          )
      return serialized
  def _chunk_markdown(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
      """Chunk markdown one heading section at a time.

      A section whose stripped text fits in ``chunk_size`` characters becomes a
      single "heading_section" chunk. A larger section is split into paragraphs
      and fenced code blocks, which are packed greedily into chunks; when a
      chunk is flushed, its last ``chunk_overlap`` characters seed the next one.
      A single part that is itself longer than ``chunk_size`` is not split
      further.

      Every chunk carries the section's ``heading_path`` in its metadata. The
      "chunk_type" of a packed chunk describes the parts that were packed into
      it: "code_block" when all of them are code blocks, otherwise "paragraph".
      Overlap text carried over from the previous chunk does not influence the
      label.

      Args:
          content: Markdown source text.
          chunk_size: Target maximum chunk length, in characters.
          chunk_overlap: Characters repeated from the previous chunk; values
              <= 0 disable the overlap.

      Returns:
          Chunks with consecutive indices starting at 0.
      """
      chunks: list[ChunkPayload] = []

      def emit(text: str, heading_path: list[str], chunk_type: str) -> None:
          # Indices are consecutive, so the next one is always len(chunks).
          chunks.append(
              _make_chunk(
                  len(chunks),
                  text,
                  {"heading_path": heading_path, "chunk_type": chunk_type},
              )
          )

      def label(part_types: set[str]) -> str:
          # Label by what the chunk holds, not by the part that forced the flush.
          return "code_block" if part_types == {"code_block"} else "paragraph"

      for heading_path, raw_section in _split_markdown_by_headings(content):
          section_text = raw_section.strip()
          if not section_text:
              continue

          if len(section_text) <= chunk_size:
              emit(section_text, heading_path, "heading_section")
              continue

          buffer = ""
          buffer_types: set[str] = set()
          for part_text, part_type in _split_markdown_section(section_text):
              # The +1 accounts for the newline that would join the part on.
              if buffer and len(buffer) + len(part_text) + 1 > chunk_size:
                  emit(buffer.strip(), heading_path, label(buffer_types))
                  overlap_text = buffer[-chunk_overlap:] if chunk_overlap > 0 else ""
                  buffer = overlap_text + "\n" + part_text
                  buffer_types = {part_type}
              else:
                  buffer = buffer + "\n" + part_text if buffer else part_text
                  buffer_types.add(part_type)

          if buffer.strip():
              emit(buffer.strip(), heading_path, label(buffer_types))

      return chunks
  def _split_markdown_by_headings(content: str) -> list[tuple[list[str], str]]:
      """Split markdown into ``(heading_path, section_text)`` tuples.

      ``heading_path`` lists the titles of the enclosing headings from the
      outermost level down to the section's own heading. ``section_text`` is the
      stripped text between that heading line and the next heading. Text before
      the first heading is returned with an empty path, and sections with no
      body text are omitted. When the document has no headings at all, the
      whole, unstripped content is returned as a single section with an empty
      path.

      Heading-like lines inside fenced code blocks are ignored. A ``# comment``
      line in a shell or Python snippet matches the heading pattern, and
      treating it as a heading would cut the code block in two and corrupt the
      heading path of everything after it.
      """
      fenced_spans = [fence.span() for fence in _CODE_BLOCK_RE.finditer(content)]

      def in_fence(offset: int) -> bool:
          return any(start <= offset < end for start, end in fenced_spans)

      # (heading start, heading end, level, title) for every real heading.
      headings = [
          (match.start(), match.end(), len(match.group(1)), match.group(2).strip())
          for match in _HEADING_RE.finditer(content)
          if not in_fence(match.start())
      ]
      if not headings:
          return [([], content)]

      sections: list[tuple[list[str], str]] = []

      preamble = content[: headings[0][0]].strip()
      if preamble:
          sections.append(([], preamble))

      active_headings: dict[int, str] = {}
      for i, (_start, body_start, level, title) in enumerate(headings):
          # A new heading closes every deeper heading that was still open.
          active_headings = {
              lvl: name for lvl, name in active_headings.items() if lvl < level
          }
          active_headings[level] = title
          path = [active_headings[lvl] for lvl in sorted(active_headings)]

          # The body runs from the end of the heading line to the next heading.
          body_end = headings[i + 1][0] if i + 1 < len(headings) else len(content)
          section_text = content[body_start:body_end].strip()
          if section_text:
              sections.append((path, section_text))

      return sections
  def _split_markdown_section(text: str) -> list[tuple[str, str]]:
      """Break a markdown section into ordered ``(text, chunk_type)`` parts.

      Fenced code blocks are kept whole and tagged "code_block". The prose
      between and around them is split on blank lines into stripped, non-empty
      parts tagged "paragraph".
      """
      pieces: list[tuple[str, str]] = []

      def add_prose(segment: str) -> None:
          # An empty or whitespace-only stretch of prose contributes nothing.
          stripped = segment.strip()
          if not stripped:
              return
          for raw_paragraph in _PARAGRAPH_SPLIT_RE.split(stripped):
              paragraph = raw_paragraph.strip()
              if paragraph:
                  pieces.append((paragraph, "paragraph"))

      cursor = 0
      for fence in _CODE_BLOCK_RE.finditer(text):
          add_prose(text[cursor:fence.start()])
          pieces.append((fence.group(), "code_block"))
          cursor = fence.end()

      # Whatever follows the last code block (or the whole text if none matched).
      add_prose(text[cursor:])
      return pieces
  def _chunk_plain_text(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
      """Chunk plain text on paragraph boundaries, then on sentences.

      Paragraphs (separated by blank lines) are packed greedily into
      "paragraph" chunks of at most roughly ``chunk_size`` characters. A
      paragraph longer than ``chunk_size`` is split into sentences, which are
      packed into "sentence" chunks. A single sentence that is still longer than
      ``chunk_size`` (for example text without punctuation) is cut into
      ``chunk_size``-character windows so that no chunk grows without bound.

      When a chunk is flushed, its last ``chunk_overlap`` characters are
      prepended to the next chunk, so a chunk can exceed ``chunk_size`` by the
      overlap plus the joining separator.

      Args:
          content: Text to chunk.
          chunk_size: Target maximum chunk length, in characters.
          chunk_overlap: Characters repeated from the previous chunk; values
              <= 0 disable the overlap.

      Returns:
          Chunks with consecutive indices starting at 0; an empty list when the
          content holds no non-blank paragraph.
      """
      paragraphs = [
          stripped
          for stripped in (raw.strip() for raw in _PARAGRAPH_SPLIT_RE.split(content.strip()))
          if stripped
      ]
      if not paragraphs:
          return []

      chunks: list[ChunkPayload] = []

      def emit(text: str, chunk_type: str) -> None:
          # Indices are consecutive, so the next one is always len(chunks).
          chunks.append(_make_chunk(len(chunks), text, {"chunk_type": chunk_type}))

      def overlap_tail(text: str) -> str:
          return text[-chunk_overlap:] if chunk_overlap > 0 else ""

      def bounded_pieces(paragraph: str) -> list[str]:
          # Sentences, with any over-long sentence hard-split into windows.
          pieces: list[str] = []
          for sentence in _split_sentences(paragraph):
              # Only hard-split for a positive size: range() rejects a zero step
              # and a negative step would silently drop the text.
              if 0 < chunk_size < len(sentence):
                  windows = (
                      sentence[start:start + chunk_size]
                      for start in range(0, len(sentence), chunk_size)
                  )
                  # A window can land on whitespace only; skip it so that no
                  # empty chunk is ever emitted.
                  pieces.extend(window for window in windows if window.strip())
              else:
                  pieces.append(sentence)
          return pieces

      buffer = ""
      for para in paragraphs:
          if len(para) > chunk_size:
              # Flush pending paragraphs before switching to sentence packing.
              if buffer:
                  emit(buffer.strip(), "paragraph")
                  buffer = ""

              sent_buffer = ""
              for piece in bounded_pieces(para):
                  if sent_buffer and len(sent_buffer) + len(piece) + 1 > chunk_size:
                      emit(sent_buffer.strip(), "sentence")
                      sent_buffer = overlap_tail(sent_buffer) + " " + piece
                  else:
                      sent_buffer = sent_buffer + " " + piece if sent_buffer else piece

              # The unflushed remainder keeps accumulating with later paragraphs.
              if sent_buffer.strip():
                  buffer = sent_buffer.strip()
          elif buffer and len(buffer) + len(para) + 2 > chunk_size:
              emit(buffer.strip(), "paragraph")
              buffer = overlap_tail(buffer) + "\n\n" + para
          else:
              buffer = buffer + "\n\n" + para if buffer else para

      if buffer.strip():
          emit(buffer.strip(), "paragraph")

      return chunks
  def _chunk_json(content: str, *, chunk_size: int) -> list[ChunkPayload]:
      """Chunk a JSON document along its top-level structure.

      * Object: each top-level key is rendered as a ``key: value`` line
        (non-string values are JSON-encoded). Lines are packed greedily into
        "json_keys" chunks whose metadata lists the keys they contain under
        "key_path".
      * Array: each item is JSON-encoded and handed to the plain-text chunker
        as its own paragraph, so chunks break on item boundaries.
      * Scalar or invalid JSON: the raw content is chunked as plain text.

      No overlap is applied in any of these cases.

      Args:
          content: JSON source text.
          chunk_size: Target maximum chunk length, in characters.

      Returns:
          Chunks with consecutive indices starting at 0.
      """
      import json as json_lib

      try:
          data = json_lib.loads(content)
      except json_lib.JSONDecodeError:
          # Not valid JSON after all: fall back to plain-text chunking.
          return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)

      if isinstance(data, dict):
          chunks: list[ChunkPayload] = []
          buffer = ""
          buffer_keys: list[str] = []

          for key, value in data.items():
              rendered = value if isinstance(value, str) else json_lib.dumps(value, ensure_ascii=False)
              line = f"{key}: {rendered}"

              # The +1 accounts for the newline that would join the line on.
              if buffer and len(buffer) + len(line) + 1 > chunk_size:
                  chunks.append(
                      _make_chunk(
                          len(chunks),
                          buffer.strip(),
                          {"chunk_type": "json_keys", "key_path": buffer_keys},
                      )
                  )
                  buffer = line
                  # Rebind (do not clear) so the flushed chunk keeps its own list.
                  buffer_keys = [key]
              else:
                  buffer = buffer + "\n" + line if buffer else line
                  buffer_keys.append(key)

          if buffer.strip():
              chunks.append(
                  _make_chunk(
                      len(chunks),
                      buffer.strip(),
                      {"chunk_type": "json_keys", "key_path": buffer_keys},
                  )
              )
          return chunks

      if isinstance(data, list):
          # Join with a blank line so every item is a separate paragraph for the
          # plain-text chunker. With a single newline the whole array was one
          # paragraph and was only ever split on sentence punctuation inside
          # the data, never between items.
          items_text = "\n\n".join(json_lib.dumps(item, ensure_ascii=False) for item in data)
          return _chunk_plain_text(items_text, chunk_size=chunk_size, chunk_overlap=0)

      return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)
  def _split_sentences(text: str) -> list[str]:
      """Return the stripped, non-empty sentences that ``_SENTENCE_SPLIT_RE`` yields for ``text``."""
      sentences: list[str] = []
      for fragment in _SENTENCE_SPLIT_RE.split(text):
          cleaned = fragment.strip()
          if cleaned:
              sentences.append(cleaned)
      return sentences
  def _make_chunk(index: int, text: str, metadata: dict[str, JSONValue]) -> ChunkPayload:
      """Build a ``ChunkPayload`` for ``text``, counting its tokens with ``tokenize``."""
      tokens = tokenize(text)
      return ChunkPayload(
          metadata_json=metadata,
          token_count=len(tokens),
          content_text=text,
          chunk_index=index,
      )