| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- """Structure-aware document chunking."""
- from __future__ import annotations
- import re
- from dataclasses import dataclass
- from core_shared import JSONValue
- from app.application.retrieval import tokenize
# ATX-style markdown heading: one to six "#" characters at the start of a line,
# whitespace, then the title. Group 1 is the "#" run (its length is the heading
# level) and group 2 is the title text.
# NOTE: "\s+" also matches newlines, so a bare "#" line directly followed by
# text on the next line is matched as a heading whose title is that next line.
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
# Triple-backtick fenced span; non-greedy and newline-spanning, so each match
# stops at the nearest closing fence. An unclosed fence does not match.
_CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```")
# Whitespace that immediately follows sentence-ending punctuation; the
# punctuation itself stays attached to the preceding sentence (lookbehind).
# NOTE(review): "!?" appears twice in the character class. The second pair may
# originally have been the full-width forms (U+FF01 / U+FF1F) and been
# flattened by text normalisation -- confirm against the canonical source.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?。!?])\s+")
# Two or more consecutive newline characters, used as the paragraph separator.
_PARAGRAPH_SPLIT_RE = re.compile(r"\n{2,}")
@dataclass(frozen=True)
class ChunkPayload:
    """Immutable record describing one chunk produced by the chunkers.

    Instances are created by ``_make_chunk`` and converted to plain dicts by
    ``chunk_document`` before being returned to callers.
    """

    # Position of the chunk in the chunker's output, starting at 0.
    chunk_index: int
    # The chunk's text.
    content_text: str
    # Number of tokens ``tokenize`` yields for ``content_text``.
    token_count: int
    # Chunker-specific metadata such as ``chunk_type``, ``heading_path`` or
    # ``key_path``.
    metadata_json: dict[str, JSONValue]
def chunk_document(
    *,
    content_text: str,
    source_type: str,
    chunk_size: int,
    chunk_overlap: int,
    raw_content: str | None = None,
) -> list[dict[str, JSONValue]]:
    """Dispatch to the appropriate chunker based on ``source_type``.

    Args:
        content_text: Document text to chunk.
        source_type: Source format label; compared case-insensitively after
            stripping whitespace. ``"markdown"``/``"md"`` select the markdown
            chunker, ``"json"`` the JSON chunker, anything else plain text.
        chunk_size: Target maximum chunk length, in characters.
        chunk_overlap: Number of trailing characters carried from one chunk
            into the next (not used by the JSON chunker).
        raw_content: Optional alternative text for markdown sources. When
            given and non-empty it is chunked instead of ``content_text``
            (presumably the unprocessed markup -- confirm with callers).
            The JSON and plain-text chunkers always use ``content_text``.

    Returns:
        One dict per chunk with the keys ``chunk_index``, ``content_text``,
        ``token_count`` and ``metadata_json``.
    """
    normalized = source_type.strip().lower()

    if normalized in {"markdown", "md"}:
        # Previously this read an undefined ``raw_content`` name and raised
        # NameError on every call; it is now an optional keyword argument.
        text_for_chunking = raw_content or content_text
        chunks = _chunk_markdown(
            text_for_chunking,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    elif normalized == "json":
        chunks = _chunk_json(content_text, chunk_size=chunk_size)
    else:
        chunks = _chunk_plain_text(
            content_text,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    return [
        {
            "chunk_index": chunk.chunk_index,
            "content_text": chunk.content_text,
            "token_count": chunk.token_count,
            "metadata_json": chunk.metadata_json,
        }
        for chunk in chunks
    ]
def _chunk_markdown(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
    """Chunk markdown by heading section, sub-splitting oversized sections.

    A section whose stripped text fits in ``chunk_size`` characters becomes a
    single ``"heading_section"`` chunk. A larger section is broken into
    paragraphs and fenced code blocks, which are packed greedily into chunks.

    Args:
        content: Markdown text.
        chunk_size: Target maximum chunk length, in characters. A single
            paragraph or code block longer than this is kept whole.
        chunk_overlap: Number of trailing characters of a flushed chunk that
            are repeated at the start of the next one; values <= 0 disable it.

    Returns:
        Chunks in document order. Each chunk's metadata carries the
        ``heading_path`` of its section and a ``chunk_type``.
    """
    chunks: list[ChunkPayload] = []

    def emit(text: str, heading_path: list[str], chunk_type: str) -> None:
        # Chunk indices are consecutive, so the next index is len(chunks).
        metadata = {"heading_path": heading_path, "chunk_type": chunk_type}
        chunks.append(_make_chunk(len(chunks), text, metadata))

    def label_for(part_types: set[str]) -> str:
        # Describe what the chunk actually contains: "code_block" only when
        # every packed part is a code block, otherwise "paragraph".
        return "code_block" if part_types == {"code_block"} else "paragraph"

    for heading_path, raw_section in _split_markdown_by_headings(content):
        section_text = raw_section.strip()
        if not section_text:
            continue
        if len(section_text) <= chunk_size:
            emit(section_text, heading_path, "heading_section")
            continue

        buffer = ""
        buffer_types: set[str] = set()
        for part_text, part_type in _split_markdown_section(section_text):
            # The +1 accounts for the newline that joins the part to the buffer.
            if buffer and len(buffer) + len(part_text) + 1 > chunk_size:
                # Label with the buffered parts' types, not the incoming
                # part's type (which belongs to the *next* chunk).
                emit(buffer.strip(), heading_path, label_for(buffer_types))
                tail = buffer[-chunk_overlap:] if chunk_overlap > 0 else ""
                # Only add the joining newline when there is overlap text, so
                # a stray separator does not count against chunk_size.
                buffer = f"{tail}\n{part_text}" if tail else part_text
                # Overlap text is not counted towards the chunk's type.
                buffer_types = {part_type}
            else:
                buffer = f"{buffer}\n{part_text}" if buffer else part_text
                buffer_types.add(part_type)

        if buffer.strip():
            emit(buffer.strip(), heading_path, label_for(buffer_types))

    return chunks
def _split_markdown_by_headings(content: str) -> list[tuple[list[str], str]]:
    """Split markdown into ``(heading_path, section_text)`` tuples.

    ``heading_path`` lists the titles of the enclosing headings from the
    shallowest level to the section's own heading. ``section_text`` is the
    stripped text between that heading line and the next heading; sections
    with no text are omitted. Text before the first heading is returned with
    an empty path. If no heading is found, the whole content is returned
    unmodified as a single section with an empty path.

    Lines that look like headings but start inside a triple-backtick fenced
    block (for example ``# comment`` in a shell snippet) are not treated as
    headings, so fenced code is never torn apart across sections.
    """
    fenced_spans = [fence.span() for fence in _CODE_BLOCK_RE.finditer(content)]

    # (heading start, heading end, level, title) for every real heading.
    headings: list[tuple[int, int, int, str]] = []
    for match in _HEADING_RE.finditer(content):
        start = match.start()
        if any(lo <= start < hi for lo, hi in fenced_spans):
            continue
        headings.append((start, match.end(), len(match.group(1)), match.group(2).strip()))

    if not headings:
        return [([], content)]

    sections: list[tuple[list[str], str]] = []

    preamble = content[: headings[0][0]].strip()
    if preamble:
        sections.append(([], preamble))

    active_headings: dict[int, str] = {}
    for i, (_start, body_start, level, title) in enumerate(headings):
        # A new heading closes every deeper heading that was still open.
        active_headings = {
            lvl: name for lvl, name in active_headings.items() if lvl < level
        }
        active_headings[level] = title
        path = [active_headings[lvl] for lvl in sorted(active_headings)]

        body_end = headings[i + 1][0] if i + 1 < len(headings) else len(content)
        # Slicing from the end of the heading match drops the heading line.
        body = content[body_start:body_end].strip()
        if body:
            sections.append((path, body))

    return sections
def _split_markdown_section(text: str) -> list[tuple[str, str]]:
    """Break a markdown section into ordered ``(text, chunk_type)`` parts.

    Fenced code blocks are emitted verbatim as ``"code_block"`` parts. The
    prose before, between and after them is split on paragraph breaks, and
    each non-empty stripped paragraph is emitted as a ``"paragraph"`` part.
    """
    parts: list[tuple[str, str]] = []

    def add_prose(segment: str) -> None:
        # An empty or whitespace-only stretch of prose contributes nothing.
        stripped = segment.strip()
        if not stripped:
            return
        for candidate in _PARAGRAPH_SPLIT_RE.split(stripped):
            paragraph = candidate.strip()
            if paragraph:
                parts.append((paragraph, "paragraph"))

    cursor = 0
    for fence in _CODE_BLOCK_RE.finditer(text):
        add_prose(text[cursor:fence.start()])
        parts.append((fence.group(), "code_block"))
        cursor = fence.end()
    add_prose(text[cursor:])

    return parts
def _chunk_plain_text(content: str, *, chunk_size: int, chunk_overlap: int) -> list[ChunkPayload]:
    """Chunk plain text by paragraph, falling back to sentences when needed.

    Paragraphs (separated by blank lines) are packed greedily into chunks of
    at most roughly ``chunk_size`` characters. A paragraph that is itself
    longer than ``chunk_size`` is split into sentences, which are packed the
    same way and emitted as ``"sentence"`` chunks; whatever is left of that
    paragraph seeds the following ``"paragraph"`` chunk.

    Args:
        content: Text to chunk.
        chunk_size: Target maximum chunk length, in characters. A single
            sentence longer than this is kept whole.
        chunk_overlap: Number of trailing characters of a flushed chunk that
            are repeated at the start of the next one; values <= 0 disable it.

    Returns:
        Chunks in document order, or an empty list when ``content`` holds no
        non-whitespace text.
    """
    paragraphs = [
        piece.strip()
        for piece in _PARAGRAPH_SPLIT_RE.split(content.strip())
        if piece.strip()
    ]
    if not paragraphs:
        return []

    chunks: list[ChunkPayload] = []

    def emit(text: str, chunk_type: str) -> None:
        # Chunk indices are consecutive, so the next index is len(chunks).
        chunks.append(_make_chunk(len(chunks), text.strip(), {"chunk_type": chunk_type}))

    def start_next(flushed: str, separator: str, addition: str) -> str:
        # Seed the next buffer with the overlap tail of the flushed one. The
        # separator is added only when there is overlap text; otherwise it
        # would be phantom whitespace counted against chunk_size.
        tail = flushed[-chunk_overlap:] if chunk_overlap > 0 else ""
        return f"{tail}{separator}{addition}" if tail else addition

    buffer = ""
    for para in paragraphs:
        if len(para) > chunk_size:
            # Flush pending paragraphs so the oversized one starts clean.
            if buffer:
                emit(buffer, "paragraph")
                buffer = ""

            sent_buffer = ""
            for sentence in _split_sentences(para):
                # The +1 accounts for the joining space.
                if sent_buffer and len(sent_buffer) + len(sentence) + 1 > chunk_size:
                    emit(sent_buffer, "sentence")
                    sent_buffer = start_next(sent_buffer, " ", sentence)
                elif sent_buffer:
                    sent_buffer = f"{sent_buffer} {sentence}"
                else:
                    sent_buffer = sentence
            # The unflushed remainder continues into the paragraph buffer.
            buffer = sent_buffer.strip()
        elif buffer and len(buffer) + len(para) + 2 > chunk_size:
            # The +2 accounts for the blank-line separator.
            emit(buffer, "paragraph")
            buffer = start_next(buffer, "\n\n", para)
        elif buffer:
            buffer = f"{buffer}\n\n{para}"
        else:
            buffer = para

    if buffer.strip():
        emit(buffer, "paragraph")
    return chunks
def _chunk_json(content: str, *, chunk_size: int) -> list[ChunkPayload]:
    """Chunk a JSON document according to its top-level shape.

    * Object: each top-level entry is rendered as ``"key: value"`` (string
      values verbatim, other values JSON-encoded) and entries are packed
      greedily into ``"json_keys"`` chunks whose ``key_path`` metadata lists
      the keys they contain.
    * Array: each item is JSON-encoded as its own paragraph and the result is
      chunked as plain text without overlap.
    * Any other value, or text that is not valid JSON: the raw ``content`` is
      chunked as plain text without overlap.

    Args:
        content: JSON text.
        chunk_size: Target maximum chunk length, in characters. A single
            entry longer than this is kept whole.
    """
    import json as json_lib

    try:
        data = json_lib.loads(content)
    except json_lib.JSONDecodeError:
        return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)

    if isinstance(data, list):
        # Items must be separated by a blank line: the plain-text chunker only
        # splits on paragraph breaks, so a single "\n" would fuse the whole
        # array into one paragraph that is never split per item.
        items_text = "\n\n".join(
            json_lib.dumps(item, ensure_ascii=False) for item in data
        )
        return _chunk_plain_text(items_text, chunk_size=chunk_size, chunk_overlap=0)

    if not isinstance(data, dict):
        return _chunk_plain_text(content, chunk_size=chunk_size, chunk_overlap=0)

    chunks: list[ChunkPayload] = []
    buffer = ""
    buffer_keys: list[str] = []

    for key, value in data.items():
        rendered = value if isinstance(value, str) else json_lib.dumps(value, ensure_ascii=False)
        line = f"{key}: {rendered}"

        # The +1 accounts for the newline that joins the line to the buffer.
        if buffer and len(buffer) + len(line) + 1 > chunk_size:
            chunks.append(
                _make_chunk(
                    len(chunks),
                    buffer.strip(),
                    {"chunk_type": "json_keys", "key_path": buffer_keys},
                )
            )
            buffer = line
            # Fresh list: the flushed chunk keeps a reference to the old one.
            buffer_keys = [key]
        else:
            buffer = f"{buffer}\n{line}" if buffer else line
            buffer_keys.append(key)

    if buffer.strip():
        chunks.append(
            _make_chunk(
                len(chunks),
                buffer.strip(),
                {"chunk_type": "json_keys", "key_path": buffer_keys},
            )
        )
    return chunks
def _split_sentences(text: str) -> list[str]:
    """Return the non-empty, stripped sentences of ``text``.

    Splitting happens on the whitespace matched by ``_SENTENCE_SPLIT_RE``,
    so terminating punctuation stays with its sentence.
    """
    sentences: list[str] = []
    for piece in _SENTENCE_SPLIT_RE.split(text):
        trimmed = piece.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences
def _make_chunk(index: int, text: str, metadata: dict[str, JSONValue]) -> ChunkPayload:
    """Build a ``ChunkPayload`` for ``text``.

    ``token_count`` is the number of items ``tokenize`` returns for ``text``;
    ``index`` and ``metadata`` are stored unchanged.
    """
    tokens = tokenize(text)
    payload = ChunkPayload(
        chunk_index=index,
        content_text=text,
        token_count=len(tokens),
        metadata_json=metadata,
    )
    return payload
|