From e8070a35576fe28b0bc8f00cc21957deb0530b50 Mon Sep 17 00:00:00 2001
From: Daniel O'Connell
Date: Sun, 25 May 2025 11:23:19 +0200
Subject: [PATCH] Use proper chunk sizes for books

Embedding previously used the full 32k-token VoyageAI context window as
the chunk size everywhere. Split the two concerns: DEFAULT_CHUNK_TOKENS
(512) for retrieval-sized chunks and EMBEDDING_MAX_TOKENS (32000) for
whole-section embedding. Book sections now keep their per-page text and
are embedded both as one large chunk and page by page, and an optional
chunk_size override is threaded through extraction and embedding.

---
 src/memory/common/chunker.py               | 51 ++++++++-------
 src/memory/common/db/models.py             |  1 +
 src/memory/common/embedding.py             | 31 +++++----
 src/memory/common/extract.py               | 30 +++++----
 src/memory/common/parsers/ebook.py         | 16 ++---
 src/memory/workers/tasks/ebook.py          | 31 ++++++---
 tests/conftest.py                          |  5 +-
 tests/memory/common/parsers/test_ebook.py  | 38 +++++------
 .../memory/workers/tasks/test_ebook_tasks.py | 63 +++++++++++++++++--
 9 files changed, 178 insertions(+), 88 deletions(-)

diff --git a/src/memory/common/chunker.py b/src/memory/common/chunker.py
index b656776..418ae1e 100644
--- a/src/memory/common/chunker.py
+++ b/src/memory/common/chunker.py
@@ -6,8 +6,9 @@
 logger = logging.getLogger(__name__)
 
 # Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
+EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
+DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
+OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
 CHARS_PER_TOKEN = 4
 
 
@@ -23,11 +24,13 @@ def approx_token_count(s: str) -> int:
     return len(s) // CHARS_PER_TOKEN
 
 
-def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_word_chunks(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
+) -> Iterable[str]:
     words = text.split()
     if not words:
         return
-
+
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
@@ -40,62 +43,64 @@ def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
     yield current
 
 
-def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
     """
     Yield text spans in priority order: paragraphs, sentences, words.
     Each span is guaranteed to be under max_tokens.
-
+
     Args:
         text: The text to split
         max_tokens: Maximum tokens per chunk
-
+
     Yields:
         Spans of text that fit within token limits
     """
     # Return early for empty text
     if not text.strip():
         return
-
+
     for paragraph in text.split("\n\n"):
         if not paragraph.strip():
             continue
-
+
         if approx_token_count(paragraph) <= max_tokens:
             yield paragraph
             continue
-
+
         for sentence in _SENT_SPLIT_RE.split(paragraph):
             if not sentence.strip():
                 continue
-
+
             if approx_token_count(sentence) <= max_tokens:
                 yield sentence
                 continue
-
+
             for chunk in yield_word_chunks(sentence, max_tokens):
                 yield chunk
 
 
-def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
+def chunk_text(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS
+) -> Iterable[str]:
     """
     Split text into chunks respecting semantic boundaries while staying within token limits.
-
+
     Args:
         text: The text to chunk
-        max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
-        overlap: Number of tokens to overlap between chunks (default: 200)
-
+        max_tokens: Maximum tokens per chunk (default: 512 for optimal semantic search)
+        overlap: Number of tokens to overlap between chunks (default: 50)
+
     Returns:
         List of text chunks
     """
     text = text.strip()
     if not text:
         return
-
+
     if approx_token_count(text) <= max_tokens:
         yield text
         return
-
+
     overlap_chars = overlap * CHARS_PER_TOKEN
 
     current = ""
@@ -108,19 +113,17 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
             yield current
             current = ""
             continue
-
+
         overlap_text = current[-overlap_chars:]
         clean_break = max(
-            overlap_text.rfind(". "),
-            overlap_text.rfind("! "),
-            overlap_text.rfind("? ")
+            overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
 
         if clean_break < 0:
             yield current
             current = ""
             continue
-
+
         break_offset = -overlap_chars + clean_break + 1
         chunk = current[break_offset:].strip()
         yield current
diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py
index b8c4902..4a5b01f 100644
--- a/src/memory/common/db/models.py
+++ b/src/memory/common/db/models.py
@@ -481,6 +481,7 @@ class BookSection(SourceItem):
         backref="children",
         foreign_keys=[parent_section_id],
     )
+    pages: list[str] = []
 
     __mapper_args__ = {"polymorphic_identity": "book_section"}
     __table_args__ = (
diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py
index 1332355..be88ac1 100644
--- a/src/memory/common/embedding.py
+++ b/src/memory/common/embedding.py
@@ -7,18 +7,12 @@ import voyageai
 from PIL import Image
 
 from memory.common import extract, settings
-from memory.common.chunker import chunk_text
+from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
 from memory.common.db.models import Chunk
 
 logger = logging.getLogger(__name__)
 
 
-# Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
-CHARS_PER_TOKEN = 4
-
-
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
@@ -150,12 +144,13 @@ def embed_text(
     texts: list[str],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
         if isinstance(text, str)
-        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
+        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
         if c.strip()
     ]
     if not chunks:
@@ -179,11 +174,12 @@ def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
-                c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
+                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
             ]
         return [item]
 
@@ -193,13 +189,17 @@ def embed_mixed(
 
 def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
+    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
     if all(isinstance(c, str) for c in contents):
         return embed_text(
-            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+            cast(list[str], contents),
+            model=settings.TEXT_EMBEDDING_MODEL,
+            chunk_size=chunk_size,
         )
     return embed_mixed(
         cast(list[extract.MulitmodalChunk], contents),
         model=settings.MIXED_EMBEDDING_MODEL,
+        chunk_size=chunk_size,
     )
@@ -255,9 +255,10 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
 ) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content)
+    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
     chunks = [
         make_chunk(page, vector, metadata)
         for page in pages
@@ -266,13 +267,17 @@ def embed(
     return modality, chunks
 
 
-def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+def embed_image(
+    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
+) -> Chunk:
     image = Image.open(file_path)
     mime_type = get_mimetype(image)
     if mime_type is None:
         raise ValueError("Unsupported image format")
 
-    vector = embed_mixed([image] + texts)[0]
+    vector = embed_mixed(
+        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
+    )[0]
 
     return Chunk(
         id=str(uuid.uuid4()),
diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py
index dc21892..83bfae0 100644
--- a/src/memory/common/extract.py
+++ b/src/memory/common/extract.py
@@ -3,7 +3,7 @@ import logging
 import pathlib
 import tempfile
 from contextlib import contextmanager
-from typing import Any, Generator, Sequence, TypedDict, cast
+from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast
 
 import pymupdf  # PyMuPDF
 import pypandoc
@@ -17,6 +17,8 @@ MulitmodalChunk = Image.Image | str
 class Page(TypedDict):
     contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
+    # This is used to override the default chunk size for the page
+    chunk_size: NotRequired[int]
 
 
 @contextmanager
@@ -108,22 +110,28 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]
 
 
-def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+def extract_content(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+) -> list[Page]:
+    pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
-        return doc_to_images(content)
-    if mime_type in [
+        pages = doc_to_images(content)
+    elif mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
         logger.info(f"Extracting content from {content}")
         pages = extract_docx(content)
         logger.info(f"Extracted {len(pages)} pages from {content}")
-        return pages
-    if mime_type.startswith("text/"):
-        return extract_text(content)
-    if mime_type.startswith("image/"):
-        return extract_image(content)
+    elif mime_type.startswith("text/"):
+        pages = extract_text(content)
+    elif mime_type.startswith("image/"):
+        pages = extract_image(content)
 
-    # Return empty list for unknown mime types
-    return []
+    if chunk_size:
+        pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
+
+    return pages
diff --git a/src/memory/common/parsers/ebook.py b/src/memory/common/parsers/ebook.py
index 1bd8186..1333fa1 100644
--- a/src/memory/common/parsers/ebook.py
+++ b/src/memory/common/parsers/ebook.py
@@ -13,7 +13,7 @@ class Section:
     """Represents a chapter or section in an ebook."""
 
     title: str
-    content: str
+    pages: list[str]
     number: int | None = None
     start_page: int | None = None
     end_page: int | None = None
@@ -74,13 +74,13 @@ def extract_epub_metadata(doc) -> dict[str, Any]:
     return {key: value for key, value in doc.metadata.items() if value}
 
 
-def get_pages(doc, start_page: int, end_page: int) -> str:
+def get_pages(doc, start_page: int, end_page: int) -> list[str]:
     pages = [
         doc[page_num].get_text()
         for page_num in range(start_page, end_page + 1)
         if 0 <= page_num < doc.page_count
     ]
-    return "\n".join(pages)
+    return pages
 
 
 def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
@@ -96,7 +96,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section |
         if not next_item:
             return Section(
                 title=name,
-                content=get_pages(doc, page, doc.page_count),
+                pages=get_pages(doc, page, doc.page_count),
                 number=section_num,
                 start_page=page,
                 end_page=doc.page_count,
@@ -110,7 +110,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section |
     last_page = next_item[2] - 1 if next_item else doc.page_count
     return Section(
         title=name,
-        content=get_pages(doc, page, last_page),
+        pages=get_pages(doc, page, last_page),
         number=section_num,
         start_page=page,
         end_page=last_page,
@@ -125,7 +125,7 @@ def extract_sections(doc) -> list[Section]:
         return [
             Section(
                 title="Content",
-                content=doc.get_text(),
+                pages=get_pages(doc, 0, doc.page_count),
                 number=1,
                 start_page=0,
                 end_page=doc.page_count,
@@ -169,7 +169,9 @@ def parse_ebook(file_path: str | Path) -> Ebook:
     sections = extract_sections(doc)
     full_content = ""
     if sections:
-        full_content = "".join(section.content for section in sections)
+        full_content = "\n\n".join(
+            page for section in sections for page in section.pages
+        )
 
     return Ebook(
         title=title,
diff --git a/src/memory/workers/tasks/ebook.py b/src/memory/workers/tasks/ebook.py
index cf5d02a..919f7f8 100644
--- a/src/memory/workers/tasks/ebook.py
+++ b/src/memory/workers/tasks/ebook.py
@@ -3,7 +3,7 @@ import logging
 from pathlib import Path
 from typing import Iterable, cast
 
-from memory.common import embedding, qdrant, settings
+from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection
 from memory.common.parsers.ebook import Ebook, parse_ebook, Section
@@ -44,7 +44,8 @@ def section_processor(
     level: int = 1,
     parent_key: tuple[int, int | None] | None = None,
 ):
-    if len(section.content.strip()) >= MIN_SECTION_LENGTH:
+    content = "\n\n".join(section.pages).strip()
+    if len(content) >= MIN_SECTION_LENGTH:
         sha256 = hashlib.sha256(
             f"{book.id}:{section.title}:{section.start_page}".encode()
         ).digest()
@@ -57,10 +58,11 @@ def section_processor(
             start_page=section.start_page,
             end_page=section.end_page,
             parent_section_id=None,  # Will be set after flush
-            content=section.content,
+            content=content,
             sha256=sha256,
             modality="book",
             tags=book.tags,
+            pages=section.pages,
         )
         all_sections.append(book_section)
 
@@ -127,13 +129,28 @@ def embed_sections(all_sections: list[BookSection]) -> int:
     """Embed all sections and return count of successfully embedded sections."""
     embedded_count = 0
 
+    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
+        _, chunks = embedding.embed(
+            "text/plain",
+            text,
+            metadata=metadata,
+            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
+        )
+        return chunks
+
     for section in all_sections:
         try:
-            _, chunks = embedding.embed(
-                "text/plain",
-                cast(str, section.content),
-                metadata=section.as_payload(),
+            section_chunks = embed_text(
+                cast(str, section.content), section.as_payload()
             )
+            page_chunks = [
+                chunk
+                for i, page in enumerate(section.pages)
+                for chunk in embed_text(
+                    page, section.as_payload() | {"page_number": i + section.start_page}
+                )
+            ]
+            chunks = section_chunks + page_chunks
             if chunks:
                 section.chunks = chunks
diff --git a/tests/conftest.py b/tests/conftest.py
index 834d60f..a3edd06 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -222,4 +222,7 @@ def qdrant():
 @pytest.fixture(autouse=True)
 def mock_voyage_client():
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
-        yield mock_client()
+        client = mock_client()
+        client.embed.return_value.embeddings = [[0.1] * 1024]
+        client.multimodal_embed.return_value.embeddings = [[0.1] * 1024]
+        yield client
diff --git a/tests/memory/common/parsers/test_ebook.py b/tests/memory/common/parsers/test_ebook.py
index 2172ce9..44a170d 100644
--- a/tests/memory/common/parsers/test_ebook.py
+++ b/tests/memory/common/parsers/test_ebook.py
@@ -83,20 +83,20 @@ def test_extract_epub_metadata(metadata_input, expected):
 @pytest.mark.parametrize(
     "start_page,end_page,expected_content",
     [
-        (0, 2, "Content of page 0\nContent of page 1\nContent of page 2"),
-        (3, 4, "Content of page 3\nContent of page 4"),
-        (4, 4, "Content of page 4"),
+        (0, 2, ["Content of page 0", "Content of page 1", "Content of page 2"]),
+        (3, 4, ["Content of page 3", "Content of page 4"]),
+        (4, 4, ["Content of page 4"]),
         (
             0,
             10,
-            "Content of page 0\nContent of page 1\nContent of page 2\nContent of page 3\nContent of page 4",
+            [f"Content of page {i}" for i in range(5)],
         ),  # Out of range
-        (5, 10, ""),  # Completely out of range
-        (3, 2, ""),  # Invalid range (start > end)
+        (5, 10, []),  # Completely out of range
+        (3, 2, []),  # Invalid range (start > end)
         (
             -1,
             2,
-            "Content of page 0\nContent of page 1\nContent of page 2",
+            [f"Content of page {i}" for i in range(3)],
         ),  # Negative start
     ],
 )
@@ -121,21 +121,21 @@ def test_extract_section_pages(mock_doc, mock_toc_items):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                ),
             ],
         )
@@ -148,21 +148,21 @@ def test_extract_sections(mock_doc):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         ),
         Section(
             title="Chapter 2",
             number=2,
             start_page=3,
             end_page=5,
-            content="Content of page 3\nContent of page 4",
+            pages=["Content of page 3", "Content of page 4"],
             children=[
                 Section(
                     title="Section 2.1",
                     number=1,
                     start_page=4,
                     end_page=5,
-                    content="Content of page 4",
+                    pages=["Content of page 4"],
                 ),
             ],
         ),
@@ -195,7 +195,7 @@ def test_extract_sections_no_toc(mock_doc):
             number=1,
             start_page=0,
             end_page=5,
-            content="Full document content",
+            pages=[f"Content of page {i}" for i in range(5)],
             children=[],
         ),
     ]
@@ -376,9 +376,9 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
 
     # Create sections with specific content
     section1 = MagicMock()
-    section1.content = "Content of section 1"
+    section1.pages = ["Content of section 1"]
section 1"] section2 = MagicMock() - section2.content = "Content of section 2" + section2.pages = ["Content of section 2"] # Mock extract_sections to return our sections with patch("memory.common.parsers.ebook.extract_sections") as mock_extract: @@ -391,4 +391,4 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path): ebook = parse_ebook(test_file) # Check the full content is concatenated correctly - assert ebook.full_content == "Content of section 1Content of section 2" + assert ebook.full_content == "Content of section 1\n\nContent of section 2" diff --git a/tests/memory/workers/tasks/test_ebook_tasks.py b/tests/memory/workers/tasks/test_ebook_tasks.py index c0d0abe..6ace439 100644 --- a/tests/memory/workers/tasks/test_ebook_tasks.py +++ b/tests/memory/workers/tasks/test_ebook_tasks.py @@ -17,22 +17,21 @@ def mock_ebook(): sections=[ Section( title="Chapter 1", - content="This is the content of chapter 1. " - * 20, # Make it long enough + pages=["This is the content of chapter 1. " * 20], number=1, start_page=1, end_page=10, children=[ Section( title="Section 1.1", - content="This is section 1.1 content. " * 15, + pages=["This is section 1.1 content. " * 15], number=1, start_page=1, end_page=5, ), Section( title="Section 1.2", - content="This is section 1.2 content. " * 15, + pages=["This is section 1.2 content. " * 15], number=2, start_page=6, end_page=10, @@ -41,7 +40,7 @@ def mock_ebook(): ), Section( title="Chapter 2", - content="This is the content of chapter 2. " * 20, + pages=["This is the content of chapter 2. " * 20], number=2, start_page=11, end_page=20, @@ -52,7 +51,7 @@ def mock_ebook(): ) -@pytest.fixture(autouse=True) +@pytest.fixture def mock_embedding(): """Mock the embedding function to return dummy vectors.""" with patch("memory.workers.tasks.ebook.embedding.embed") as mock: @@ -316,3 +315,55 @@ def test_sync_book_file_not_found(): """Test handling of missing files.""" with pytest.raises(FileNotFoundError): ebook.sync_book("/nonexistent/file.epub") + + +def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client): + """Test that book sections with large pages are passed whole to the embedding function.""" + # Create a test book first + book = Book( + title="Test Book", + author="Test Author", + file_path="/test/path", + ) + db_session.add(book) + db_session.flush() + + # Create large content that exceeds 1000 tokens (4000+ characters) + large_section_content = "This is a very long section content. " * 120 # ~4440 chars + large_page_1 = "This is page 1 with lots of content. " * 120 # ~4440 chars + large_page_2 = "This is page 2 with lots of content. 
" * 120 # ~4440 chars + + # Create test sections with large content and pages + sections = [ + BookSection( + book_id=book.id, + section_title="Test Section", + section_number=1, + section_level=1, + start_page=1, + end_page=10, + content=large_section_content, + sha256=b"test_hash", + modality="book", + tags=["book"], + pages=[large_page_1, large_page_2], + ) + ] + + db_session.add_all(sections) + db_session.flush() + + ebook.embed_sections(sections) + + # Verify that the voyage client was called with the full large content + # Should be called 3 times: once for section content, twice for pages + assert mock_voyage_client.embed.call_count == 3 + + # Check that the full content was passed to the embedding function + calls = mock_voyage_client.embed.call_args_list + texts = [c[0][0] for c in calls] + assert texts == [ + [large_section_content.strip()], + [large_page_1.strip()], + [large_page_2.strip()], + ]