proper chunk sizes for books

Daniel O'Connell 2025-05-25 11:23:19 +02:00
parent eb69221999
commit e8070a3557
9 changed files with 178 additions and 88 deletions

View File

@@ -6,8 +6,9 @@ logger = logging.getLogger(__name__)
 # Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
+EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
+DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
+OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
 CHARS_PER_TOKEN = 4

@@ -23,7 +24,9 @@ def approx_token_count(s: str) -> int:
     return len(s) // CHARS_PER_TOKEN

-def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_word_chunks(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
+) -> Iterable[str]:
     words = text.split()
     if not words:
         return

@@ -40,7 +43,7 @@ def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
     yield current

-def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
     """
     Yield text spans in priority order: paragraphs, sentences, words.
     Each span is guaranteed to be under max_tokens.

@@ -76,14 +79,16 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
     yield chunk

-def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
+def chunk_text(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS
+) -> Iterable[str]:
     """
     Split text into chunks respecting semantic boundaries while staying within token limits.

     Args:
         text: The text to chunk
-        max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
-        overlap: Number of tokens to overlap between chunks (default: 200)
+        max_tokens: Maximum tokens per chunk (default: 512 for optimal semantic search)
+        overlap: Number of tokens to overlap between chunks (default: 50)

     Returns:
         List of text chunks

@@ -111,9 +116,7 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
         overlap_text = current[-overlap_chars:]
         clean_break = max(
-            overlap_text.rfind(". "),
-            overlap_text.rfind("! "),
-            overlap_text.rfind("? ")
+            overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
         if clean_break < 0:
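
Note: a quick sketch of what the new defaults mean in practice (illustrative only; the sample text and numbers are made up, not from the commit):

from memory.common.chunker import chunk_text, approx_token_count

text = "Some sentence here. " * 2000  # ~40k chars, ~10k approximate tokens
chunks = list(chunk_text(text))  # defaults: 512-token chunks, 50-token overlap
# Each chunk stays around or under 512 approximate tokens (~2048 chars at
# 4 chars/token), and consecutive chunks repeat up to ~50 tokens of trailing
# context, preferring to restart at a sentence boundary.
print(len(chunks), max(approx_token_count(c) for c in chunks))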

View File

@@ -481,6 +481,7 @@ class BookSection(SourceItem):
         backref="children",
         foreign_keys=[parent_section_id],
     )
+    pages: list[str] = []

     __mapper_args__ = {"polymorphic_identity": "book_section"}
     __table_args__ = (
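
Note (a reading of the change, not stated in the commit): pages appears to be a plain class attribute rather than a mapped Column, so SQLAlchemy's declarative constructor accepts it but nothing is persisted; it just ferries per-page text to the embedding step. Sketch:

# illustrative: pages rides along in memory; it is not a table column
section = BookSection(modality="book", pages=["page 1 text", "page 2 text"])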

View File

@@ -7,18 +7,12 @@ import voyageai
 from PIL import Image

 from memory.common import extract, settings
-from memory.common.chunker import chunk_text
+from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
 from memory.common.db.models import Chunk

 logger = logging.getLogger(__name__)

-# Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
-CHARS_PER_TOKEN = 4
-
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]

@@ -150,12 +144,13 @@ def embed_text(
     texts: list[str],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
         if isinstance(text, str)
-        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
+        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
         if c.strip()
     ]
     if not chunks:

@@ -179,11 +174,12 @@ def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
-                c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
+                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
             ]
         return [item]

@@ -193,13 +189,17 @@
 def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
+    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
     if all(isinstance(c, str) for c in contents):
         return embed_text(
-            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+            cast(list[str], contents),
+            model=settings.TEXT_EMBEDDING_MODEL,
+            chunk_size=chunk_size,
         )
     return embed_mixed(
         cast(list[extract.MulitmodalChunk], contents),
         model=settings.MIXED_EMBEDDING_MODEL,
+        chunk_size=chunk_size,
     )

@@ -255,9 +255,10 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
 ) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content)
+    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
     chunks = [
         make_chunk(page, vector, metadata)
         for page in pages

@@ -266,13 +267,17 @@
     return modality, chunks

-def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+def embed_image(
+    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
+) -> Chunk:
     image = Image.open(file_path)
     mime_type = get_mimetype(image)
     if mime_type is None:
         raise ValueError("Unsupported image format")

-    vector = embed_mixed([image] + texts)[0]
+    vector = embed_mixed(
+        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
+    )[0]

     return Chunk(
         id=str(uuid.uuid4()),
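
Note on the API effect, as a hedged sketch (the strings are illustrative): callers that pass nothing keep the 512-token default, while the book pipeline opts out by passing the model's full window so whole pages survive as single chunks:

from memory.common import chunker, embedding

# default: text is split into 512-token chunks before embedding
modality, chunks = embedding.embed("text/plain", "long document text ...")

# books pass the full context window, so each section/page stays one chunk
modality, chunks = embedding.embed(
    "text/plain",
    "an entire book section ...",
    chunk_size=chunker.EMBEDDING_MAX_TOKENS,
)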

View File

@@ -3,7 +3,7 @@ import logging
 import pathlib
 import tempfile
 from contextlib import contextmanager
-from typing import Any, Generator, Sequence, TypedDict, cast
+from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast

 import pymupdf  # PyMuPDF
 import pypandoc

@@ -17,6 +17,8 @@ MulitmodalChunk = Image.Image | str
 class Page(TypedDict):
     contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
+    # This is used to override the default chunk size for the page
+    chunk_size: NotRequired[int]

 @contextmanager

@@ -108,22 +110,28 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]

-def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+def extract_content(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+) -> list[Page]:
+    pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
-        return doc_to_images(content)
-    if mime_type in [
+        pages = doc_to_images(content)
+    elif mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
         logger.info(f"Extracting content from {content}")
         pages = extract_docx(content)
         logger.info(f"Extracted {len(pages)} pages from {content}")
-        return pages
-    if mime_type.startswith("text/"):
-        return extract_text(content)
-    if mime_type.startswith("image/"):
-        return extract_image(content)
+    elif mime_type.startswith("text/"):
+        pages = extract_text(content)
+    elif mime_type.startswith("image/"):
+        pages = extract_image(content)

-    # Return empty list for unknown mime types
-    return []
+    if chunk_size:
+        pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
+    return pages
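
Note: a minimal sketch of the new pass-through for a plain-text payload:

from memory.common.extract import extract_content

pages = extract_content("text/plain", "some text", chunk_size=1024)
# -> [{"contents": ["some text"], "metadata": {}, "chunk_size": 1024}]
# Without chunk_size the pages come back untouched, and embed_page later
# falls back to DEFAULT_CHUNK_TOKENS.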

View File

@@ -13,7 +13,7 @@ class Section:
     """Represents a chapter or section in an ebook."""

     title: str
-    content: str
+    pages: list[str]
     number: int | None = None
     start_page: int | None = None
     end_page: int | None = None

@@ -74,13 +74,13 @@ def extract_epub_metadata(doc) -> dict[str, Any]:
     return {key: value for key, value in doc.metadata.items() if value}

-def get_pages(doc, start_page: int, end_page: int) -> str:
+def get_pages(doc, start_page: int, end_page: int) -> list[str]:
     pages = [
         doc[page_num].get_text()
         for page_num in range(start_page, end_page + 1)
         if 0 <= page_num < doc.page_count
     ]
-    return "\n".join(pages)
+    return pages

 def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:

@@ -96,7 +96,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
     if not next_item:
         return Section(
             title=name,
-            content=get_pages(doc, page, doc.page_count),
+            pages=get_pages(doc, page, doc.page_count),
             number=section_num,
             start_page=page,
             end_page=doc.page_count,

@@ -110,7 +110,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
     last_page = next_item[2] - 1 if next_item else doc.page_count
     return Section(
         title=name,
-        content=get_pages(doc, page, last_page),
+        pages=get_pages(doc, page, last_page),
         number=section_num,
         start_page=page,
         end_page=last_page,

@@ -125,7 +125,7 @@ def extract_sections(doc) -> list[Section]:
     return [
         Section(
             title="Content",
-            content=doc.get_text(),
+            pages=get_pages(doc, 0, doc.page_count),
             number=1,
             start_page=0,
             end_page=doc.page_count,

@@ -169,7 +169,9 @@ def parse_ebook(file_path: str | Path) -> Ebook:
     sections = extract_sections(doc)
     full_content = ""
     if sections:
-        full_content = "".join(section.content for section in sections)
+        full_content = "\n\n".join(
+            page for section in sections for page in section.pages
+        )

     return Ebook(
         title=title,
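
Note: with per-page storage, rebuilding a section's text is a simple join. A small sketch (the title and text are made up):

from memory.common.parsers.ebook import Section

section = Section(title="Chapter 1", pages=["Page one text.", "Page two text."])
full_text = "\n\n".join(section.pages)
# -> "Page one text.\n\nPage two text."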

View File

@@ -3,7 +3,7 @@ import logging
 from pathlib import Path
 from typing import Iterable, cast

-from memory.common import embedding, qdrant, settings
+from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection
 from memory.common.parsers.ebook import Ebook, parse_ebook, Section

@@ -44,7 +44,8 @@ def section_processor(
     level: int = 1,
     parent_key: tuple[int, int | None] | None = None,
 ):
-    if len(section.content.strip()) >= MIN_SECTION_LENGTH:
+    content = "\n\n".join(section.pages).strip()
+    if len(content) >= MIN_SECTION_LENGTH:
         sha256 = hashlib.sha256(
             f"{book.id}:{section.title}:{section.start_page}".encode()
         ).digest()

@@ -57,10 +58,11 @@ def section_processor(
             start_page=section.start_page,
             end_page=section.end_page,
             parent_section_id=None,  # Will be set after flush
-            content=section.content,
+            content=content,
             sha256=sha256,
             modality="book",
             tags=book.tags,
+            pages=section.pages,
         )
         all_sections.append(book_section)

@@ -127,13 +129,28 @@ def embed_sections(all_sections: list[BookSection]) -> int:
     """Embed all sections and return count of successfully embedded sections."""
     embedded_count = 0

-    for section in all_sections:
-        try:
-            _, chunks = embedding.embed(
-                "text/plain",
-                cast(str, section.content),
-                metadata=section.as_payload(),
-            )
+    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
+        _, chunks = embedding.embed(
+            "text/plain",
+            text,
+            metadata=metadata,
+            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
+        )
+        return chunks
+
+    for section in all_sections:
+        try:
+            section_chunks = embed_text(
+                cast(str, section.content), section.as_payload()
+            )
+            page_chunks = [
+                chunk
+                for i, page in enumerate(section.pages)
+                for chunk in embed_text(
+                    page, section.as_payload() | {"page_number": i + section.start_page}
+                )
+            ]
+            chunks = section_chunks + page_chunks

             if chunks:
                 section.chunks = chunks
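
Note: the net effect is that each sufficiently long section is embedded twice over, once whole and once per page, with the page's absolute number recorded in the payload. A sketch of the resulting metadata (fields other than page_number depend on as_payload and are illustrative):

section_payload = {"title": "Chapter 1", "start_page": 12}  # illustrative
page_payloads = [section_payload | {"page_number": 12 + i} for i in range(2)]
# page-level chunks carry page_number 12 and 13; the section-level chunk
# keeps the bare section payload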

View File

@@ -222,4 +222,7 @@ def qdrant():
 @pytest.fixture(autouse=True)
 def mock_voyage_client():
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
-        yield mock_client()
+        client = mock_client()
+        client.embed.return_value.embeddings = [[0.1] * 1024]
+        client.multimodal_embed.return_value.embeddings = [[0.1] * 1024]
+        yield client
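
Note: the fixture now hands back deterministic 1024-dimensional vectors, so embedding code paths run end to end in tests without calling VoyageAI. E.g. (a sketch; the model name is made up):

vectors = client.embed(["any text"], model="some-model").embeddings
# -> [[0.1, 0.1, ..., 0.1]]  (1024 floats)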

View File

@@ -83,20 +83,20 @@ def test_extract_epub_metadata(metadata_input, expected):
 @pytest.mark.parametrize(
     "start_page,end_page,expected_content",
     [
-        (0, 2, "Content of page 0\nContent of page 1\nContent of page 2"),
-        (3, 4, "Content of page 3\nContent of page 4"),
-        (4, 4, "Content of page 4"),
+        (0, 2, ["Content of page 0", "Content of page 1", "Content of page 2"]),
+        (3, 4, ["Content of page 3", "Content of page 4"]),
+        (4, 4, ["Content of page 4"]),
         (
             0,
             10,
-            "Content of page 0\nContent of page 1\nContent of page 2\nContent of page 3\nContent of page 4",
+            [f"Content of page {i}" for i in range(5)],
         ),  # Out of range
-        (5, 10, ""),  # Completely out of range
-        (3, 2, ""),  # Invalid range (start > end)
+        (5, 10, []),  # Completely out of range
+        (3, 2, []),  # Invalid range (start > end)
         (
             -1,
             2,
-            "Content of page 0\nContent of page 1\nContent of page 2",
+            [f"Content of page {i}" for i in range(3)],
         ),  # Negative start
     ],
 )

@@ -121,21 +121,21 @@ def test_extract_section_pages(mock_doc, mock_toc_items):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         )

@@ -148,21 +148,21 @@ def test_extract_sections(mock_doc):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         ),

@@ -171,14 +171,14 @@ def test_extract_sections(mock_doc):
             number=2,
             start_page=3,
             end_page=5,
-            content="Content of page 3\nContent of page 4",
+            pages=["Content of page 3", "Content of page 4"],
             children=[
                 Section(
                     title="Section 2.1",
                     number=1,
                     start_page=4,
                     end_page=5,
-                    content="Content of page 4",
+                    pages=["Content of page 4"],
                 ),
             ],
         ),

@@ -195,7 +195,7 @@ def test_extract_sections_no_toc(mock_doc):
             number=1,
             start_page=0,
             end_page=5,
-            content="Full document content",
+            pages=[f"Content of page {i}" for i in range(5)],
             children=[],
         ),
     ]

@@ -376,9 +376,9 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
     # Create sections with specific content
     section1 = MagicMock()
-    section1.content = "Content of section 1"
+    section1.pages = ["Content of section 1"]
     section2 = MagicMock()
-    section2.content = "Content of section 2"
+    section2.pages = ["Content of section 2"]

     # Mock extract_sections to return our sections
     with patch("memory.common.parsers.ebook.extract_sections") as mock_extract:

@@ -391,4 +391,4 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
     ebook = parse_ebook(test_file)

     # Check the full content is concatenated correctly
-    assert ebook.full_content == "Content of section 1Content of section 2"
+    assert ebook.full_content == "Content of section 1\n\nContent of section 2"

View File

@@ -17,22 +17,21 @@ def mock_ebook():
     sections=[
         Section(
             title="Chapter 1",
-            content="This is the content of chapter 1. "
-            * 20,  # Make it long enough
+            pages=["This is the content of chapter 1. " * 20],
             number=1,
             start_page=1,
             end_page=10,
             children=[
                 Section(
                     title="Section 1.1",
-                    content="This is section 1.1 content. " * 15,
+                    pages=["This is section 1.1 content. " * 15],
                     number=1,
                     start_page=1,
                     end_page=5,
                 ),
                 Section(
                     title="Section 1.2",
-                    content="This is section 1.2 content. " * 15,
+                    pages=["This is section 1.2 content. " * 15],
                     number=2,
                     start_page=6,
                     end_page=10,

@@ -41,7 +40,7 @@ def mock_ebook():
         ),
         Section(
             title="Chapter 2",
-            content="This is the content of chapter 2. " * 20,
+            pages=["This is the content of chapter 2. " * 20],
             number=2,
             start_page=11,
             end_page=20,

@@ -52,7 +51,7 @@ def mock_ebook():
 )

-@pytest.fixture(autouse=True)
+@pytest.fixture
 def mock_embedding():
     """Mock the embedding function to return dummy vectors."""
     with patch("memory.workers.tasks.ebook.embedding.embed") as mock:

@@ -316,3 +315,55 @@ def test_sync_book_file_not_found():
     """Test handling of missing files."""
     with pytest.raises(FileNotFoundError):
         ebook.sync_book("/nonexistent/file.epub")
+
+
+def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
+    """Test that book sections with large pages are passed whole to the embedding function."""
+    # Create a test book first
+    book = Book(
+        title="Test Book",
+        author="Test Author",
+        file_path="/test/path",
+    )
+    db_session.add(book)
+    db_session.flush()
+
+    # Create large content that exceeds 1000 tokens (4000+ characters)
+    large_section_content = "This is a very long section content. " * 120  # ~4440 chars
+    large_page_1 = "This is page 1 with lots of content. " * 120  # ~4440 chars
+    large_page_2 = "This is page 2 with lots of content. " * 120  # ~4440 chars
+
+    # Create test sections with large content and pages
+    sections = [
+        BookSection(
+            book_id=book.id,
+            section_title="Test Section",
+            section_number=1,
+            section_level=1,
+            start_page=1,
+            end_page=10,
+            content=large_section_content,
+            sha256=b"test_hash",
+            modality="book",
+            tags=["book"],
+            pages=[large_page_1, large_page_2],
+        )
+    ]
+    db_session.add_all(sections)
+    db_session.flush()
+
+    ebook.embed_sections(sections)
+
+    # Verify that the voyage client was called with the full large content.
+    # Should be called 3 times: once for section content, twice for pages
+    assert mock_voyage_client.embed.call_count == 3
+
+    # Check that the full content was passed to the embedding function
+    calls = mock_voyage_client.embed.call_args_list
+    texts = [c[0][0] for c in calls]
+    assert texts == [
+        [large_section_content.strip()],
+        [large_page_1.strip()],
+        [large_page_2.strip()],
+    ]