"""Structure-aware document chunking.""" from __future__ import annotations import re from dataclasses import dataclass from core_shared import JSONValue from app.application.retrieval import tokenize _HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) _CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```") _SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?。!?])\s+") _PARAGRAPH_SPLIT_RE = re.compile(r"\n{2,}") @dataclass(frozen=True) class ChunkPayload: chunk_index: int content_text: str token_count: int metadata_json: dict[str, JSONValue] def chunk_document( *, content_text: str, source_type: str, chunk_size: int, chunk_overlap: int, ) -> list[dict[str, JSONValue]]: """Dispatch to the appropriate chunker based on source_type.""" normalized = source_type.strip().lower() text_for_chunking = raw_content or content_text if normalized in {"markdown", "md"}: chunks = _chunk_markdown(text_for_chunking, chunk_size=chunk_size, chunk_overlap=chunk_overlap) elif normalized == "json": chunks = _chunk_json(content_text, chunk_size=chunk_size) else: chunks = _chunk_plain_text(content_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap) return [ { "chunk_index": c.chunk_index, "content_text": c.content_text, "token_count": c.token_count, "metadata_json": c.metadata_json, } for c in chunks ] def _chunk_markdown(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]: sections = _split_markdown_by_headings(content) chunks: list[ChunkPayload] = [] index = 0 for heading_path, section_text in sections: section_text = section_text.strip() if not section_text: continue if len(section_text) <= chunk_size: chunks.append(_make_chunk(index, section_text, {"heading_path": heading_path, "chunk_type": "heading_section"})) index += 1 continue sub_parts = _split_markdown_section(section_text) buffer = "" for part_text, part_type in sub_parts: if len(buffer) + len(part_text) + 1 > chunk_size and buffer: chunks.append(_make_chunk(index, buffer.strip(), {"heading_path": heading_path, "chunk_type": part_type})) index += 1 overlap_text = buffer[-chunk_overlap:] if chunk_overlap > 0 else "" buffer = overlap_text + "\n" + part_text else: buffer = buffer + "\n" + part_text if buffer else part_text if buffer.strip(): chunks.append(_make_chunk(index, buffer.strip(), {"heading_path": heading_path, "chunk_type": "paragraph"})) index += 1 return chunks def _split_markdown_by_headings(content: str) -> list[tuple[list[str], str]]: """Split markdown into (heading_path, section_text) tuples.""" positions: list[tuple[int, int, str]] = [] for match in _HEADING_RE.finditer(content): level = len(match.group(1)) title = match.group(2).strip() positions.append((match.start(), level, title)) if not positions: return [([], content)] sections: list[tuple[list[str], str]] = [] active_headings: dict[int, str] = {} first_pos = positions[0][0] if first_pos > 0: preamble = content[:first_pos].strip() if preamble: sections.append(([], preamble)) for i, (pos, level, title) in enumerate(positions): active_headings[level] = title for higher in list(active_headings): if higher > level: del active_headings[higher] path = [active_headings[l] for l in sorted(active_headings)] end = positions[i + 1][0] if i + 1 < len(positions) else len(content) section_text = content[pos:end] section_text = re.sub(r"^#{1,6}\s+.+$", "", section_text, count=1, flags=re.MULTILINE).strip() if section_text: sections.append((path, section_text)) return sections def _split_markdown_section(text: str) -> list[tuple[str, str]]: """Split a markdown section into (text, chunk_type) parts.""" parts: list[tuple[str, str]] = [] last_end = 0 for match in _CODE_BLOCK_RE.finditer(text): if match.start() > last_end: prose = text[last_end:match.start()].strip() if prose: for para in _PARAGRAPH_SPLIT_RE.split(prose): p = para.strip() if p: parts.append((p, "paragraph")) code = match.group() parts.append((code, "code_block")) last_end = match.end() if last_end < len(text): remaining = text[last_end:].strip() if remaining: for para in _PARAGRAPH_SPLIT_RE.split(remaining): p = para.strip() if p: parts.append((p, "paragraph")) return parts def _chunk_plain_text(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]: paragraphs = _PARAGRAPH_SPLIT_RE.split(content.strip()) paragraphs = [p.strip() for p in paragraphs if p.strip()] if not paragraphs: return [] chunks: list[ChunkPayload] = [] buffer = "" index = 0 for para in paragraphs: if len(para) > chunk_size: if buffer: chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"})) index += 1 buffer = "" sentences = _split_sentences(para) sent_buffer = "" for sentence in sentences: if len(sent_buffer) + len(sentence) + 1 > chunk_size and sent_buffer: chunks.append(_make_chunk(index, sent_buffer.strip(), {"chunk_type": "sentence"})) index += 1 overlap_text = sent_buffer[-chunk_overlap:] if chunk_overlap > 0 else "" sent_buffer = overlap_text + " " + sentence else: sent_buffer = sent_buffer + " " + sentence if sent_buffer else sentence if sent_buffer.strip(): buffer = sent_buffer.strip() elif len(buffer) + len(para) + 2 > chunk_size and buffer: chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"})) index += 1 overlap_text = buffer[-chunk_overlap:] if chunk_overlap > 0 else "" buffer = overlap_text + "\n\n" + para else: buffer = buffer + "\n\n" + para if buffer else para if buffer.strip(): chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "paragraph"})) return chunks def _chunk_json(content: str, *, chunk_size: int) -> list[ChunkPayload]: import json as json_lib try: data = json_lib.loads(content) except json_lib.JSONDecodeError: return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0) if isinstance(data, dict): parts: list[tuple[str, str]] = [] for key, value in data.items(): line = f"{key}: {json_lib.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value}" parts.append((key, line)) chunks: list[ChunkPayload] = [] buffer = "" buffer_keys: list[str] = [] index = 0 for key, line in parts: if len(buffer) + len(line) + 1 > chunk_size and buffer: chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "json_keys", "key_path": buffer_keys})) index += 1 buffer = line buffer_keys = [key] else: buffer = buffer + "\n" + line if buffer else line buffer_keys.append(key) if buffer.strip(): chunks.append(_make_chunk(index, buffer.strip(), {"chunk_type": "json_keys", "key_path": buffer_keys})) return chunks if isinstance(data, list): items_text = "\n".join(json_lib.dumps(item, ensure_ascii=False) for item in data) return _chunk_plain_text(items_text, chunk_size=chunk_size, chunk_overlap=0) return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0) def _split_sentences(text: str) -> list[str]: parts = _SENTENCE_SPLIT_RE.split(text) return [p.strip() for p in parts if p.strip()] def _make_chunk(index: int, text: str, metadata: dict[str, JSONValue]) -> ChunkPayload: return ChunkPayload( chunk_index=index, content_text=text, token_count=len(tokenize(text)), metadata_json=metadata, )