mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 21:34:42 +02:00)

commit e8070a3557 (parent eb69221999)

    proper chunk sizes for books
@@ -6,8 +6,9 @@ logger = logging.getLogger(__name__)
 
 
 # Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
+EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
+DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
+OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
 CHARS_PER_TOKEN = 4
 
 
@@ -23,7 +24,9 @@ def approx_token_count(s: str) -> int:
     return len(s) // CHARS_PER_TOKEN
 
 
-def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_word_chunks(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
+) -> Iterable[str]:
     words = text.split()
     if not words:
         return
@@ -40,7 +43,7 @@ def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
     yield current
 
 
-def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
     """
     Yield text spans in priority order: paragraphs, sentences, words.
     Each span is guaranteed to be under max_tokens.
@@ -76,14 +79,16 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
             yield chunk
 
 
-def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
+def chunk_text(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS
+) -> Iterable[str]:
     """
     Split text into chunks respecting semantic boundaries while staying within token limits.
 
     Args:
         text: The text to chunk
-        max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
-        overlap: Number of tokens to overlap between chunks (default: 200)
+        max_tokens: Maximum tokens per chunk (default: 512 for optimal semantic search)
+        overlap: Number of tokens to overlap between chunks (default: 50)
 
     Returns:
         List of text chunks
@@ -111,9 +116,7 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
 
         overlap_text = current[-overlap_chars:]
         clean_break = max(
-            overlap_text.rfind(". "),
-            overlap_text.rfind("! "),
-            overlap_text.rfind("? ")
+            overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )
 
         if clean_break < 0:
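Net effect of the chunker changes: EMBEDDING_MAX_TOKENS keeps the 32k ceiling of the embedding model, while DEFAULT_CHUNK_TOKENS (512) with a 50-token overlap becomes the default for every chunking helper. A minimal usage sketch under the new defaults (the sample text is made up):

    from memory.common.chunker import chunk_text, EMBEDDING_MAX_TOKENS

    text = "Some long document body. " * 2000  # made-up input

    # Default: ~512-token chunks with a 50-token overlap, suited to semantic search.
    search_chunks = list(chunk_text(text))

    # Callers that want whole sections kept together (e.g. books) pass the model limit.
    book_chunks = list(chunk_text(text, max_tokens=EMBEDDING_MAX_TOKENS))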
@@ -481,6 +481,7 @@ class BookSection(SourceItem):
         backref="children",
         foreign_keys=[parent_section_id],
     )
+    pages: list[str] = []
 
     __mapper_args__ = {"polymorphic_identity": "book_section"}
     __table_args__ = (
@@ -7,18 +7,12 @@ import voyageai
 from PIL import Image
 
 from memory.common import extract, settings
-from memory.common.chunker import chunk_text
+from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
 from memory.common.db.models import Chunk
 
 logger = logging.getLogger(__name__)
 
 
-# Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
-CHARS_PER_TOKEN = 4
-
-
 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]
 
@@ -150,12 +144,13 @@ def embed_text(
     texts: list[str],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
         if isinstance(text, str)
-        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
+        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
         if c.strip()
     ]
     if not chunks:
@@ -179,11 +174,12 @@ def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
-                c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
+                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
             ]
         return [item]
 
@@ -193,13 +189,17 @@ def embed_mixed(
 
 def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
+    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
     if all(isinstance(c, str) for c in contents):
         return embed_text(
-            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+            cast(list[str], contents),
+            model=settings.TEXT_EMBEDDING_MODEL,
+            chunk_size=chunk_size,
         )
     return embed_mixed(
         cast(list[extract.MulitmodalChunk], contents),
         model=settings.MIXED_EMBEDDING_MODEL,
+        chunk_size=chunk_size,
     )
 
 
@@ -255,9 +255,10 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
 ) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content)
+    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
     chunks = [
         make_chunk(page, vector, metadata)
         for page in pages
@@ -266,13 +267,17 @@ def embed(
     return modality, chunks
 
 
-def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+def embed_image(
+    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
+) -> Chunk:
     image = Image.open(file_path)
     mime_type = get_mimetype(image)
     if mime_type is None:
         raise ValueError("Unsupported image format")
 
-    vector = embed_mixed([image] + texts)[0]
+    vector = embed_mixed(
+        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
+    )[0]
 
     return Chunk(
         id=str(uuid.uuid4()),
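Taken together, the embedding helpers (embed, embed_text, embed_mixed, embed_page, embed_image) now thread an optional chunk_size down to chunk_text, falling back to DEFAULT_CHUNK_TOKENS. A rough sketch of how a caller opts into larger chunks (values are illustrative, not from the diff):

    from memory.common import embedding
    from memory.common.chunker import EMBEDDING_MAX_TOKENS

    # Default behaviour: text is split into ~512-token chunks before embedding.
    modality, chunks = embedding.embed("text/plain", "some document text", metadata={})

    # Book-style behaviour: pass the model's full context window so each section
    # or page is embedded as a single chunk (this is what embed_sections does below).
    modality, chunks = embedding.embed(
        "text/plain",
        "a very long book section ...",
        metadata={},
        chunk_size=EMBEDDING_MAX_TOKENS,
    )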
@@ -3,7 +3,7 @@ import logging
 import pathlib
 import tempfile
 from contextlib import contextmanager
-from typing import Any, Generator, Sequence, TypedDict, cast
+from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast
 
 import pymupdf  # PyMuPDF
 import pypandoc
@@ -17,6 +17,8 @@ MulitmodalChunk = Image.Image | str
 class Page(TypedDict):
     contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
+    # This is used to override the default chunk size for the page
+    chunk_size: NotRequired[int]
 
 
 @contextmanager
@@ -108,22 +110,28 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]
 
 
-def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+def extract_content(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+) -> list[Page]:
+    pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
-        return doc_to_images(content)
-    if mime_type in [
+        pages = doc_to_images(content)
+    elif mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
         logger.info(f"Extracting content from {content}")
         pages = extract_docx(content)
         logger.info(f"Extracted {len(pages)} pages from {content}")
-        return pages
-    if mime_type.startswith("text/"):
-        return extract_text(content)
-    if mime_type.startswith("image/"):
-        return extract_image(content)
+    elif mime_type.startswith("text/"):
+        pages = extract_text(content)
+    elif mime_type.startswith("image/"):
+        pages = extract_image(content)
 
-    # Return empty list for unknown mime types
-    return []
+    if chunk_size:
+        pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
+
+    return pages
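extract_content now collects pages from each branch and, when a chunk_size is given, copies it onto every Page so that embed_page can pick it up via page.get("chunk_size", DEFAULT_CHUNK_TOKENS). A small sketch of the override for plain text (the input string is made up):

    from memory.common import extract

    pages = extract.extract_content("text/plain", "plain text body", chunk_size=32000)
    # Based on extract_text above, this should yield roughly:
    # [{"contents": ["plain text body"], "metadata": {}, "chunk_size": 32000}]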
@@ -13,7 +13,7 @@ class Section:
     """Represents a chapter or section in an ebook."""
 
     title: str
-    content: str
+    pages: list[str]
     number: int | None = None
     start_page: int | None = None
     end_page: int | None = None
@@ -74,13 +74,13 @@ def extract_epub_metadata(doc) -> dict[str, Any]:
     return {key: value for key, value in doc.metadata.items() if value}
 
 
-def get_pages(doc, start_page: int, end_page: int) -> str:
+def get_pages(doc, start_page: int, end_page: int) -> list[str]:
     pages = [
         doc[page_num].get_text()
         for page_num in range(start_page, end_page + 1)
         if 0 <= page_num < doc.page_count
     ]
-    return "\n".join(pages)
+    return pages
 
 
 def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
@@ -96,7 +96,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section |
     if not next_item:
         return Section(
             title=name,
-            content=get_pages(doc, page, doc.page_count),
+            pages=get_pages(doc, page, doc.page_count),
             number=section_num,
             start_page=page,
             end_page=doc.page_count,
@@ -110,7 +110,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section |
     last_page = next_item[2] - 1 if next_item else doc.page_count
     return Section(
         title=name,
-        content=get_pages(doc, page, last_page),
+        pages=get_pages(doc, page, last_page),
         number=section_num,
         start_page=page,
         end_page=last_page,
@@ -125,7 +125,7 @@ def extract_sections(doc) -> list[Section]:
         return [
             Section(
                 title="Content",
-                content=doc.get_text(),
+                pages=get_pages(doc, 0, doc.page_count),
                 number=1,
                 start_page=0,
                 end_page=doc.page_count,
@@ -169,7 +169,9 @@ def parse_ebook(file_path: str | Path) -> Ebook:
     sections = extract_sections(doc)
     full_content = ""
     if sections:
-        full_content = "".join(section.content for section in sections)
+        full_content = "\n\n".join(
+            page for section in sections for page in section.pages
+        )
 
     return Ebook(
         title=title,
@@ -3,7 +3,7 @@ import logging
 from pathlib import Path
 from typing import Iterable, cast
 
-from memory.common import embedding, qdrant, settings
+from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection
 from memory.common.parsers.ebook import Ebook, parse_ebook, Section
@@ -44,7 +44,8 @@ def section_processor(
     level: int = 1,
     parent_key: tuple[int, int | None] | None = None,
 ):
-    if len(section.content.strip()) >= MIN_SECTION_LENGTH:
+    content = "\n\n".join(section.pages).strip()
+    if len(content) >= MIN_SECTION_LENGTH:
         sha256 = hashlib.sha256(
             f"{book.id}:{section.title}:{section.start_page}".encode()
         ).digest()
@@ -57,10 +58,11 @@ def section_processor(
             start_page=section.start_page,
             end_page=section.end_page,
             parent_section_id=None,  # Will be set after flush
-            content=section.content,
+            content=content,
             sha256=sha256,
             modality="book",
             tags=book.tags,
+            pages=section.pages,
         )
 
         all_sections.append(book_section)
@@ -127,13 +129,28 @@ def embed_sections(all_sections: list[BookSection]) -> int:
     """Embed all sections and return count of successfully embedded sections."""
     embedded_count = 0
 
-    for section in all_sections:
-        try:
-            _, chunks = embedding.embed(
-                "text/plain",
-                cast(str, section.content),
-                metadata=section.as_payload(),
-            )
+    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
+        _, chunks = embedding.embed(
+            "text/plain",
+            text,
+            metadata=metadata,
+            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
+        )
+        return chunks
+
+    for section in all_sections:
+        try:
+            section_chunks = embed_text(
+                cast(str, section.content), section.as_payload()
+            )
+            page_chunks = [
+                chunk
+                for i, page in enumerate(section.pages)
+                for chunk in embed_text(
+                    page, section.as_payload() | {"page_number": i + section.start_page}
+                )
+            ]
+            chunks = section_chunks + page_chunks
 
             if chunks:
                 section.chunks = chunks
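Note that embed_sections now embeds each section twice over: once for the joined section content and once per page, both with chunk_size=chunker.EMBEDDING_MAX_TOKENS so nothing is re-split, and with an absolute page number added to the per-page payload. A minimal sketch of that page-number arithmetic (the payload shape is assumed, not taken from the diff):

    section_payload = {"title": "Chapter 2", "modality": "book"}  # assumed shape
    start_page = 11
    pages = ["first page text", "second page text"]

    per_page_metadata = [
        section_payload | {"page_number": i + start_page}
        for i, _page in enumerate(pages)
    ]
    # -> page_number 11 for the first page, 12 for the second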
@@ -222,4 +222,7 @@ def qdrant():
 @pytest.fixture(autouse=True)
 def mock_voyage_client():
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
-        yield mock_client()
+        client = mock_client()
+        client.embed.return_value.embeddings = [[0.1] * 1024]
+        client.multimodal_embed.return_value.embeddings = [[0.1] * 1024]
+        yield client
@@ -83,20 +83,20 @@ def test_extract_epub_metadata(metadata_input, expected):
 @pytest.mark.parametrize(
     "start_page,end_page,expected_content",
     [
-        (0, 2, "Content of page 0\nContent of page 1\nContent of page 2"),
-        (3, 4, "Content of page 3\nContent of page 4"),
-        (4, 4, "Content of page 4"),
+        (0, 2, ["Content of page 0", "Content of page 1", "Content of page 2"]),
+        (3, 4, ["Content of page 3", "Content of page 4"]),
+        (4, 4, ["Content of page 4"]),
         (
             0,
             10,
-            "Content of page 0\nContent of page 1\nContent of page 2\nContent of page 3\nContent of page 4",
+            [f"Content of page {i}" for i in range(5)],
         ),  # Out of range
-        (5, 10, ""),  # Completely out of range
-        (3, 2, ""),  # Invalid range (start > end)
+        (5, 10, []),  # Completely out of range
+        (3, 2, []),  # Invalid range (start > end)
         (
             -1,
             2,
-            "Content of page 0\nContent of page 1\nContent of page 2",
+            [f"Content of page {i}" for i in range(3)],
         ),  # Negative start
     ],
 )
@@ -121,21 +121,21 @@ def test_extract_section_pages(mock_doc, mock_toc_items):
         number=1,
         start_page=0,
         end_page=2,
-        content="Content of page 0\nContent of page 1\nContent of page 2",
+        pages=["Content of page 0", "Content of page 1", "Content of page 2"],
         children=[
             Section(
                 title="Section 1.1",
                 number=1,
                 start_page=1,
                 end_page=1,
-                content="Content of page 1",
+                pages=["Content of page 1"],
             ),
             Section(
                 title="Section 1.2",
                 number=2,
                 start_page=2,
                 end_page=2,
-                content="Content of page 2",
+                pages=["Content of page 2"],
            ),
         ],
     )
@@ -148,21 +148,21 @@ def test_extract_sections(mock_doc):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         ),
@@ -171,14 +171,14 @@ def test_extract_sections(mock_doc):
             number=2,
             start_page=3,
             end_page=5,
-            content="Content of page 3\nContent of page 4",
+            pages=["Content of page 3", "Content of page 4"],
             children=[
                 Section(
                     title="Section 2.1",
                     number=1,
                     start_page=4,
                     end_page=5,
-                    content="Content of page 4",
+                    pages=["Content of page 4"],
                 ),
             ],
         ),
@@ -195,7 +195,7 @@ def test_extract_sections_no_toc(mock_doc):
             number=1,
             start_page=0,
             end_page=5,
-            content="Full document content",
+            pages=[f"Content of page {i}" for i in range(5)],
             children=[],
         ),
     ]
@@ -376,9 +376,9 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
 
     # Create sections with specific content
     section1 = MagicMock()
-    section1.content = "Content of section 1"
+    section1.pages = ["Content of section 1"]
     section2 = MagicMock()
-    section2.content = "Content of section 2"
+    section2.pages = ["Content of section 2"]
 
     # Mock extract_sections to return our sections
     with patch("memory.common.parsers.ebook.extract_sections") as mock_extract:
@@ -391,4 +391,4 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
         ebook = parse_ebook(test_file)
 
     # Check the full content is concatenated correctly
-    assert ebook.full_content == "Content of section 1Content of section 2"
+    assert ebook.full_content == "Content of section 1\n\nContent of section 2"
@@ -17,22 +17,21 @@ def mock_ebook():
         sections=[
             Section(
                 title="Chapter 1",
-                content="This is the content of chapter 1. "
-                * 20,  # Make it long enough
+                pages=["This is the content of chapter 1. " * 20],
                 number=1,
                 start_page=1,
                 end_page=10,
                 children=[
                     Section(
                         title="Section 1.1",
-                        content="This is section 1.1 content. " * 15,
+                        pages=["This is section 1.1 content. " * 15],
                         number=1,
                         start_page=1,
                         end_page=5,
                     ),
                     Section(
                         title="Section 1.2",
-                        content="This is section 1.2 content. " * 15,
+                        pages=["This is section 1.2 content. " * 15],
                         number=2,
                         start_page=6,
                         end_page=10,
@@ -41,7 +40,7 @@ def mock_ebook():
             ),
             Section(
                 title="Chapter 2",
-                content="This is the content of chapter 2. " * 20,
+                pages=["This is the content of chapter 2. " * 20],
                 number=2,
                 start_page=11,
                 end_page=20,
@@ -52,7 +51,7 @@ def mock_ebook():
     )
 
 
-@pytest.fixture(autouse=True)
+@pytest.fixture
 def mock_embedding():
     """Mock the embedding function to return dummy vectors."""
     with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
@@ -316,3 +315,55 @@ def test_sync_book_file_not_found():
     """Test handling of missing files."""
     with pytest.raises(FileNotFoundError):
         ebook.sync_book("/nonexistent/file.epub")
+
+
+def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
+    """Test that book sections with large pages are passed whole to the embedding function."""
+    # Create a test book first
+    book = Book(
+        title="Test Book",
+        author="Test Author",
+        file_path="/test/path",
+    )
+    db_session.add(book)
+    db_session.flush()
+
+    # Create large content that exceeds 1000 tokens (4000+ characters)
+    large_section_content = "This is a very long section content. " * 120  # ~4440 chars
+    large_page_1 = "This is page 1 with lots of content. " * 120  # ~4440 chars
+    large_page_2 = "This is page 2 with lots of content. " * 120  # ~4440 chars
+
+    # Create test sections with large content and pages
+    sections = [
+        BookSection(
+            book_id=book.id,
+            section_title="Test Section",
+            section_number=1,
+            section_level=1,
+            start_page=1,
+            end_page=10,
+            content=large_section_content,
+            sha256=b"test_hash",
+            modality="book",
+            tags=["book"],
+            pages=[large_page_1, large_page_2],
+        )
+    ]
+
+    db_session.add_all(sections)
+    db_session.flush()
+
+    ebook.embed_sections(sections)
+
+    # Verify that the voyage client was called with the full large content
+    # Should be called 3 times: once for section content, twice for pages
+    assert mock_voyage_client.embed.call_count == 3
+
+    # Check that the full content was passed to the embedding function
+    calls = mock_voyage_client.embed.call_args_list
+    texts = [c[0][0] for c in calls]
+    assert texts == [
+        [large_section_content.strip()],
+        [large_page_1.strip()],
+        [large_page_2.strip()],
+    ]