proper chunk sizes for books

Daniel O'Connell 2025-05-25 11:23:19 +02:00
parent eb69221999
commit e8070a3557
9 changed files with 178 additions and 88 deletions

View File

@@ -6,8 +6,9 @@ logger = logging.getLogger(__name__)
 # Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
+EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
+DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
+OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
 CHARS_PER_TOKEN = 4
@@ -23,7 +24,9 @@ def approx_token_count(s: str) -> int:
     return len(s) // CHARS_PER_TOKEN
-def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_word_chunks(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
+) -> Iterable[str]:
     words = text.split()
     if not words:
         return
@@ -40,7 +43,7 @@ def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
     yield current
-def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
     """
     Yield text spans in priority order: paragraphs, sentences, words.
     Each span is guaranteed to be under max_tokens.
@@ -76,14 +79,16 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
             yield chunk
-def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
+def chunk_text(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS
+) -> Iterable[str]:
     """
     Split text into chunks respecting semantic boundaries while staying within token limits.
     Args:
         text: The text to chunk
-        max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
-        overlap: Number of tokens to overlap between chunks (default: 200)
+        max_tokens: Maximum tokens per chunk (default: 512 for optimal semantic search)
+        overlap: Number of tokens to overlap between chunks (default: 50)
     Returns:
         List of text chunks
@@ -111,9 +116,7 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
         overlap_text = current[-overlap_chars:]
         clean_break = max(
-            overlap_text.rfind(". "),
-            overlap_text.rfind("! "),
-            overlap_text.rfind("? ")
+            overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
         if clean_break < 0:
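
A quick sanity check of the new defaults, sketched against the signatures above (the sample text and the final assertion are illustrative, leaning on the docstring's "under max_tokens" guarantee):

from memory.common.chunker import (
    DEFAULT_CHUNK_TOKENS,
    approx_token_count,
    chunk_text,
)

# A long document now yields many ~512-token chunks that overlap by ~50
# tokens, with the overlap snapped back to a sentence boundary when one
# exists, instead of a single 32k-token chunk.
text = "This is a sentence about books. " * 1000
chunks = list(chunk_text(text))

assert len(chunks) > 1
for chunk in chunks:
    assert approx_token_count(chunk) <= DEFAULT_CHUNK_TOKENS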

View File

@@ -481,6 +481,7 @@ class BookSection(SourceItem):
         backref="children",
         foreign_keys=[parent_section_id],
     )
+    pages: list[str] = []
     __mapper_args__ = {"polymorphic_identity": "book_section"}
     __table_args__ = (
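
Worth flagging on this model change (a general Python caveat, not something this diff exercises): a bare list[str] = [] class attribute is a single shared list for every instance that never assigns its own, so if SQLAlchemy leaves it unmapped it behaves like this:

class Demo:
    pages: list[str] = []  # one list object, stored on the class

a, b = Demo(), Demo()
a.pages.append("page 1")  # mutates the shared class-level default
print(b.pages)  # ['page 1'], leaked between instances

The worker below always passes pages=... when constructing BookSection, which sidesteps the issue, but a per-instance default would be more defensive.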

View File

@@ -7,18 +7,12 @@ import voyageai
 from PIL import Image
 from memory.common import extract, settings
-from memory.common.chunker import chunk_text
+from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
 from memory.common.db.models import Chunk
 logger = logging.getLogger(__name__)
-# Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
-CHARS_PER_TOKEN = 4
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
@@ -150,12 +144,13 @@ def embed_text(
     texts: list[str],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
         if isinstance(text, str)
-        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
+        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
         if c.strip()
     ]
     if not chunks:
@@ -179,11 +174,12 @@ def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
-                c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
+                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
             ]
         return [item]
@@ -193,13 +189,17 @@ def embed_mixed(
 def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
+    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
     if all(isinstance(c, str) for c in contents):
         return embed_text(
-            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+            cast(list[str], contents),
+            model=settings.TEXT_EMBEDDING_MODEL,
+            chunk_size=chunk_size,
         )
     return embed_mixed(
         cast(list[extract.MulitmodalChunk], contents),
         model=settings.MIXED_EMBEDDING_MODEL,
+        chunk_size=chunk_size,
     )
@@ -255,9 +255,10 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
 ) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content)
+    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
     chunks = [
         make_chunk(page, vector, metadata)
         for page in pages
@@ -266,13 +267,17 @@
     return modality, chunks
-def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+def embed_image(
+    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
+) -> Chunk:
     image = Image.open(file_path)
     mime_type = get_mimetype(image)
     if mime_type is None:
         raise ValueError("Unsupported image format")
-    vector = embed_mixed([image] + texts)[0]
+    vector = embed_mixed(
+        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
+    )[0]
     return Chunk(
         id=str(uuid.uuid4()),
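
To make the plumbing explicit: chunk_size now travels from embed() through extract_content() into each Page, and embed_page() reads it back out. A usage sketch (the document text is invented):

from memory.common import embedding

# Default path: text is split into 512-token chunks with 50-token overlap.
modality, chunks = embedding.embed("text/plain", "Some long document...")

# Override: every Page returned by extract_content() carries
# chunk_size=32000, which embed_page() forwards to embed_text() /
# embed_mixed(), so the text is embedded without being split.
modality, chunks = embedding.embed(
    "text/plain",
    "Some long document...",
    chunk_size=32000,
)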

View File

@@ -3,7 +3,7 @@ import logging
 import pathlib
 import tempfile
 from contextlib import contextmanager
-from typing import Any, Generator, Sequence, TypedDict, cast
+from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast
 import pymupdf  # PyMuPDF
 import pypandoc
@@ -17,6 +17,8 @@ MulitmodalChunk = Image.Image | str
 class Page(TypedDict):
     contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
+    # This is used to override the default chunk size for the page
+    chunk_size: NotRequired[int]
 @contextmanager
@@ -108,22 +110,28 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]
-def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+def extract_content(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+) -> list[Page]:
+    pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
-        return doc_to_images(content)
-    if mime_type in [
+        pages = doc_to_images(content)
+    elif mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
         logger.info(f"Extracting content from {content}")
         pages = extract_docx(content)
         logger.info(f"Extracted {len(pages)} pages from {content}")
-        return pages
-    if mime_type.startswith("text/"):
-        return extract_text(content)
-    if mime_type.startswith("image/"):
-        return extract_image(content)
+    elif mime_type.startswith("text/"):
+        pages = extract_text(content)
+    elif mime_type.startswith("image/"):
+        pages = extract_image(content)
-    # Return empty list for unknown mime types
-    return []
+    if chunk_size:
+        pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
+    return pages
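
Because chunk_size is NotRequired, existing Page producers stay valid; the key only appears when a caller asks for it. A small illustration:

from memory.common.extract import extract_content

# No override: pages carry no chunk_size key, and embed_page() falls back
# to the default via page.get("chunk_size", DEFAULT_CHUNK_TOKENS).
pages = extract_content("text/plain", "hello world")
assert "chunk_size" not in pages[0]

# With an override: every returned page is annotated.
pages = extract_content("text/plain", "hello world", chunk_size=32000)
assert pages[0]["chunk_size"] == 32000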

View File

@@ -13,7 +13,7 @@ class Section:
     """Represents a chapter or section in an ebook."""
     title: str
-    content: str
+    pages: list[str]
     number: int | None = None
     start_page: int | None = None
    end_page: int | None = None
@@ -74,13 +74,13 @@ def extract_epub_metadata(doc) -> dict[str, Any]:
     return {key: value for key, value in doc.metadata.items() if value}
-def get_pages(doc, start_page: int, end_page: int) -> str:
+def get_pages(doc, start_page: int, end_page: int) -> list[str]:
     pages = [
         doc[page_num].get_text()
         for page_num in range(start_page, end_page + 1)
         if 0 <= page_num < doc.page_count
     ]
-    return "\n".join(pages)
+    return pages
def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
@ -96,7 +96,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section |
if not next_item:
return Section(
title=name,
content=get_pages(doc, page, doc.page_count),
pages=get_pages(doc, page, doc.page_count),
number=section_num,
start_page=page,
end_page=doc.page_count,
@@ -110,7 +110,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
     last_page = next_item[2] - 1 if next_item else doc.page_count
     return Section(
         title=name,
-        content=get_pages(doc, page, last_page),
+        pages=get_pages(doc, page, last_page),
         number=section_num,
         start_page=page,
         end_page=last_page,
@@ -125,7 +125,7 @@ def extract_sections(doc) -> list[Section]:
     return [
         Section(
             title="Content",
-            content=doc.get_text(),
+            pages=get_pages(doc, 0, doc.page_count),
             number=1,
             start_page=0,
             end_page=doc.page_count,
@@ -169,7 +169,9 @@ def parse_ebook(file_path: str | Path) -> Ebook:
     sections = extract_sections(doc)
     full_content = ""
     if sections:
-        full_content = "".join(section.content for section in sections)
+        full_content = "\n\n".join(
+            page for section in sections for page in section.pages
+        )
     return Ebook(
         title=title,
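
Section now stores per-page text rather than one pre-joined string, leaving the join to callers. A toy example (the data is invented):

from memory.common.parsers.ebook import Section

section = Section(
    title="Chapter 1",
    pages=["Page one text", "Page two text"],
    number=1,
    start_page=0,
    end_page=1,
)

# Consumers re-join only when they need flat text, as the ebook worker does:
content = "\n\n".join(section.pages).strip()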

View File

@@ -3,7 +3,7 @@ import logging
 from pathlib import Path
 from typing import Iterable, cast
-from memory.common import embedding, qdrant, settings
+from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection
 from memory.common.parsers.ebook import Ebook, parse_ebook, Section
@@ -44,7 +44,8 @@ def section_processor(
     level: int = 1,
     parent_key: tuple[int, int | None] | None = None,
 ):
-    if len(section.content.strip()) >= MIN_SECTION_LENGTH:
+    content = "\n\n".join(section.pages).strip()
+    if len(content) >= MIN_SECTION_LENGTH:
         sha256 = hashlib.sha256(
             f"{book.id}:{section.title}:{section.start_page}".encode()
         ).digest()
@@ -57,10 +58,11 @@
             start_page=section.start_page,
             end_page=section.end_page,
             parent_section_id=None,  # Will be set after flush
-            content=section.content,
+            content=content,
             sha256=sha256,
             modality="book",
             tags=book.tags,
+            pages=section.pages,
         )
         all_sections.append(book_section)
@@ -127,13 +129,28 @@ def embed_sections(all_sections: list[BookSection]) -> int:
     """Embed all sections and return count of successfully embedded sections."""
     embedded_count = 0
+    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
+        _, chunks = embedding.embed(
+            "text/plain",
+            text,
+            metadata=metadata,
+            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
+        )
+        return chunks
     for section in all_sections:
         try:
-            _, chunks = embedding.embed(
-                "text/plain",
-                cast(str, section.content),
-                metadata=section.as_payload(),
+            section_chunks = embed_text(
+                cast(str, section.content), section.as_payload()
             )
+            page_chunks = [
+                chunk
+                for i, page in enumerate(section.pages)
+                for chunk in embed_text(
+                    page, section.as_payload() | {"page_number": i + section.start_page}
+                )
+            ]
+            chunks = section_chunks + page_chunks
             if chunks:
                 section.chunks = chunks
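
Spelled out: each section is now embedded twice over, once whole and once per page, and because the helper pins chunk_size to chunker.EMBEDDING_MAX_TOKENS (32k tokens), chunk_text() returns each text intact rather than splitting it into 512-token pieces. A toy sketch of the payloads this produces (field names and values invented, standing in for as_payload()):

payload = {"title": "Chapter 1", "start_page": 12}
pages = ["text of page 12", "text of page 13"]

page_payloads = [payload | {"page_number": 12 + i} for i in range(len(pages))]
print(page_payloads)
# [{'title': 'Chapter 1', 'start_page': 12, 'page_number': 12},
#  {'title': 'Chapter 1', 'start_page': 12, 'page_number': 13}]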

View File

@@ -222,4 +222,7 @@ def qdrant():
 @pytest.fixture(autouse=True)
 def mock_voyage_client():
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
-        yield mock_client()
+        client = mock_client()
+        client.embed.return_value.embeddings = [[0.1] * 1024]
+        client.multimodal_embed.return_value.embeddings = [[0.1] * 1024]
+        yield client
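
With embeddings stubbed in, tests can both consume a well-formed response and inspect the recorded calls. A minimal sketch of what the fixture now yields (the 1024-dim size comes from the stub above, not from any claim about the real model):

def test_embed_stub(mock_voyage_client):
    result = mock_voyage_client.embed(["some text"], model="anything")
    assert result.embeddings == [[0.1] * 1024]
    # call_args_list is what the new chunk-size test below inspects.
    assert mock_voyage_client.embed.call_args_list[-1][0][0] == ["some text"]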

View File

@@ -83,20 +83,20 @@ def test_extract_epub_metadata(metadata_input, expected):
 @pytest.mark.parametrize(
     "start_page,end_page,expected_content",
     [
-        (0, 2, "Content of page 0\nContent of page 1\nContent of page 2"),
-        (3, 4, "Content of page 3\nContent of page 4"),
-        (4, 4, "Content of page 4"),
+        (0, 2, ["Content of page 0", "Content of page 1", "Content of page 2"]),
+        (3, 4, ["Content of page 3", "Content of page 4"]),
+        (4, 4, ["Content of page 4"]),
         (
             0,
             10,
-            "Content of page 0\nContent of page 1\nContent of page 2\nContent of page 3\nContent of page 4",
+            [f"Content of page {i}" for i in range(5)],
         ),  # Out of range
-        (5, 10, ""),  # Completely out of range
-        (3, 2, ""),  # Invalid range (start > end)
+        (5, 10, []),  # Completely out of range
+        (3, 2, []),  # Invalid range (start > end)
         (
             -1,
             2,
-            "Content of page 0\nContent of page 1\nContent of page 2",
+            [f"Content of page {i}" for i in range(3)],
         ),  # Negative start
     ],
 )
@@ -121,21 +121,21 @@ def test_extract_section_pages(mock_doc, mock_toc_items):
         number=1,
         start_page=0,
         end_page=2,
-        content="Content of page 0\nContent of page 1\nContent of page 2",
+        pages=["Content of page 0", "Content of page 1", "Content of page 2"],
         children=[
             Section(
                 title="Section 1.1",
                 number=1,
                 start_page=1,
                 end_page=1,
-                content="Content of page 1",
+                pages=["Content of page 1"],
             ),
             Section(
                 title="Section 1.2",
                 number=2,
                 start_page=2,
                 end_page=2,
-                content="Content of page 2",
+                pages=["Content of page 2"],
             ),
         ],
     )
@@ -148,21 +148,21 @@ def test_extract_sections(mock_doc):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         ),
@@ -171,14 +171,14 @@ def test_extract_sections(mock_doc):
             number=2,
             start_page=3,
             end_page=5,
-            content="Content of page 3\nContent of page 4",
+            pages=["Content of page 3", "Content of page 4"],
             children=[
                 Section(
                     title="Section 2.1",
                     number=1,
                     start_page=4,
                     end_page=5,
-                    content="Content of page 4",
+                    pages=["Content of page 4"],
                 ),
             ],
         ),
@@ -195,7 +195,7 @@ def test_extract_sections_no_toc(mock_doc):
             number=1,
             start_page=0,
             end_page=5,
-            content="Full document content",
+            pages=[f"Content of page {i}" for i in range(5)],
             children=[],
         ),
     ]
@@ -376,9 +376,9 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
     # Create sections with specific content
     section1 = MagicMock()
-    section1.content = "Content of section 1"
+    section1.pages = ["Content of section 1"]
     section2 = MagicMock()
-    section2.content = "Content of section 2"
+    section2.pages = ["Content of section 2"]
     # Mock extract_sections to return our sections
     with patch("memory.common.parsers.ebook.extract_sections") as mock_extract:
@@ -391,4 +391,4 @@
     ebook = parse_ebook(test_file)
     # Check the full content is concatenated correctly
-    assert ebook.full_content == "Content of section 1Content of section 2"
+    assert ebook.full_content == "Content of section 1\n\nContent of section 2"

View File

@@ -17,22 +17,21 @@ def mock_ebook():
         sections=[
             Section(
                 title="Chapter 1",
-                content="This is the content of chapter 1. "
-                * 20,  # Make it long enough
+                pages=["This is the content of chapter 1. " * 20],
                 number=1,
                 start_page=1,
                 end_page=10,
                 children=[
                     Section(
                         title="Section 1.1",
-                        content="This is section 1.1 content. " * 15,
+                        pages=["This is section 1.1 content. " * 15],
                         number=1,
                         start_page=1,
                         end_page=5,
                     ),
                     Section(
                         title="Section 1.2",
-                        content="This is section 1.2 content. " * 15,
+                        pages=["This is section 1.2 content. " * 15],
                         number=2,
                         start_page=6,
                         end_page=10,
@@ -41,7 +40,7 @@ def mock_ebook():
             ),
             Section(
                 title="Chapter 2",
-                content="This is the content of chapter 2. " * 20,
+                pages=["This is the content of chapter 2. " * 20],
                 number=2,
                 start_page=11,
                 end_page=20,
@@ -52,7 +51,7 @@
 )
-@pytest.fixture(autouse=True)
+@pytest.fixture
 def mock_embedding():
     """Mock the embedding function to return dummy vectors."""
     with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
@@ -316,3 +315,55 @@ def test_sync_book_file_not_found():
     """Test handling of missing files."""
     with pytest.raises(FileNotFoundError):
         ebook.sync_book("/nonexistent/file.epub")
+def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
+    """Test that book sections with large pages are passed whole to the embedding function."""
+    # Create a test book first
+    book = Book(
+        title="Test Book",
+        author="Test Author",
+        file_path="/test/path",
+    )
+    db_session.add(book)
+    db_session.flush()
+    # Create large content that exceeds 1000 tokens (4000+ characters)
+    large_section_content = "This is a very long section content. " * 120  # ~4440 chars
+    large_page_1 = "This is page 1 with lots of content. " * 120  # ~4440 chars
+    large_page_2 = "This is page 2 with lots of content. " * 120  # ~4440 chars
+    # Create test sections with large content and pages
+    sections = [
+        BookSection(
+            book_id=book.id,
+            section_title="Test Section",
+            section_number=1,
+            section_level=1,
+            start_page=1,
+            end_page=10,
+            content=large_section_content,
+            sha256=b"test_hash",
+            modality="book",
+            tags=["book"],
+            pages=[large_page_1, large_page_2],
+        )
+    ]
+    db_session.add_all(sections)
+    db_session.flush()
+    ebook.embed_sections(sections)
+    # Verify that the voyage client was called with the full large content
+    # Should be called 3 times: once for section content, twice for pages
+    assert mock_voyage_client.embed.call_count == 3
+    # Check that the full content was passed to the embedding function
+    calls = mock_voyage_client.embed.call_args_list
+    texts = [c[0][0] for c in calls]
+    assert texts == [
+        [large_section_content.strip()],
+        [large_page_1.strip()],
+        [large_page_2.strip()],
+    ]
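
A cross-check on the test's premise, using the chunker's CHARS_PER_TOKEN = 4 heuristic: 37 characters per repetition times 120 repetitions is 4440 characters, roughly 1110 approximate tokens, more than double the 512-token default. Without the chunk_size=chunker.EMBEDDING_MAX_TOKENS override, each text would arrive at the mock split into ~512-token pieces and the final texts assertion would fail.

chars = len("This is a very long section content. " * 120)
print(chars, chars // 4)  # 4440 1110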