mirror of https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00

proper chunk sizes for books

parent eb69221999
commit e8070a3557
@@ -6,8 +6,9 @@ logger = logging.getLogger(__name__)

 # Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
+EMBEDDING_MAX_TOKENS = 32000  # VoyageAI max context window
+DEFAULT_CHUNK_TOKENS = 512  # Optimal chunk size for semantic search
+OVERLAP_TOKENS = 50  # Default overlap between chunks (10% of chunk size)
 CHARS_PER_TOKEN = 4


@@ -23,7 +24,9 @@ def approx_token_count(s: str) -> int:
     return len(s) // CHARS_PER_TOKEN


-def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_word_chunks(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
+) -> Iterable[str]:
     words = text.split()
     if not words:
         return
@@ -40,7 +43,7 @@ def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
         yield current


-def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
+def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
     """
     Yield text spans in priority order: paragraphs, sentences, words.
     Each span is guaranteed to be under max_tokens.
@@ -76,14 +79,16 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
             yield chunk


-def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
+def chunk_text(
+    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS
+) -> Iterable[str]:
     """
     Split text into chunks respecting semantic boundaries while staying within token limits.

     Args:
         text: The text to chunk
-        max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
-        overlap: Number of tokens to overlap between chunks (default: 200)
+        max_tokens: Maximum tokens per chunk (default: 512 for optimal semantic search)
+        overlap: Number of tokens to overlap between chunks (default: 50)

     Returns:
         List of text chunks
@@ -111,9 +116,7 @@ def chunk_text(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:

         overlap_text = current[-overlap_chars:]
         clean_break = max(
-            overlap_text.rfind(". "),
-            overlap_text.rfind("! "),
-            overlap_text.rfind("? ")
+            overlap_text.rfind(". "), overlap_text.rfind("! "), overlap_text.rfind("? ")
         )

         if clean_break < 0:
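Taken together, the new defaults budget roughly DEFAULT_CHUNK_TOKENS * CHARS_PER_TOKEN = 2048 characters per chunk with about 200 characters of overlap, instead of the previous 32,000-token chunks. A minimal sketch of how a caller might exercise the new defaults (the sample text is made up; the imports assume the exports shown in this hunk):

    # Illustrative only: exercises the new defaults from memory.common.chunker.
    from memory.common.chunker import (
        chunk_text,
        DEFAULT_CHUNK_TOKENS,
        OVERLAP_TOKENS,
        CHARS_PER_TOKEN,
    )

    text = "A sentence about chunking. " * 500   # roughly 13,500 characters of sample text
    chunks = list(chunk_text(text))              # defaults: 512-token chunks, 50-token overlap

    # Rough per-chunk size under the same 4-chars-per-token estimate the module uses.
    for c in chunks[:3]:
        print(len(c) // CHARS_PER_TOKEN, "tokens (approx), overlap target:",
              OVERLAP_TOKENS * CHARS_PER_TOKEN, "chars")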
@@ -481,6 +481,7 @@ class BookSection(SourceItem):
         backref="children",
         foreign_keys=[parent_section_id],
     )
+    pages: list[str] = []

     __mapper_args__ = {"polymorphic_identity": "book_section"}
     __table_args__ = (
@@ -7,18 +7,12 @@ import voyageai
 from PIL import Image

 from memory.common import extract, settings
-from memory.common.chunker import chunk_text
+from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
 from memory.common.db.models import Chunk

 logger = logging.getLogger(__name__)

-
-# Chunking configuration
-MAX_TOKENS = 32000  # VoyageAI max context window
-OVERLAP_TOKENS = 200  # Default overlap between chunks
-CHARS_PER_TOKEN = 4
-

 DistanceType = Literal["Cosine", "Dot", "Euclidean"]
 Vector = list[float]

@@ -150,12 +144,13 @@ def embed_text(
     texts: list[str],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunks = [
         c
         for text in texts
         if isinstance(text, str)
-        for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
+        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
         if c.strip()
     ]
     if not chunks:
@@ -179,11 +174,12 @@ def embed_mixed(
     items: list[extract.MulitmodalChunk],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
+    chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
         if isinstance(item, str):
             return [
-                c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
+                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
             ]
         return [item]

@@ -193,13 +189,17 @@ def embed_mixed(

 def embed_page(page: extract.Page) -> list[Vector]:
     contents = page["contents"]
+    chunk_size = page.get("chunk_size", DEFAULT_CHUNK_TOKENS)
     if all(isinstance(c, str) for c in contents):
         return embed_text(
-            cast(list[str], contents), model=settings.TEXT_EMBEDDING_MODEL
+            cast(list[str], contents),
+            model=settings.TEXT_EMBEDDING_MODEL,
+            chunk_size=chunk_size,
         )
     return embed_mixed(
         cast(list[extract.MulitmodalChunk], contents),
         model=settings.MIXED_EMBEDDING_MODEL,
+        chunk_size=chunk_size,
     )


@@ -255,9 +255,10 @@ def embed(
     mime_type: str,
     content: bytes | str | pathlib.Path,
     metadata: dict[str, Any] = {},
+    chunk_size: int | None = None,
 ) -> tuple[str, list[Chunk]]:
     modality = get_modality(mime_type)
-    pages = extract.extract_content(mime_type, content)
+    pages = extract.extract_content(mime_type, content, chunk_size=chunk_size)
     chunks = [
         make_chunk(page, vector, metadata)
         for page in pages
@@ -266,13 +267,17 @@ def embed(
     return modality, chunks


-def embed_image(file_path: pathlib.Path, texts: list[str]) -> Chunk:
+def embed_image(
+    file_path: pathlib.Path, texts: list[str], chunk_size: int | None = None
+) -> Chunk:
     image = Image.open(file_path)
     mime_type = get_mimetype(image)
     if mime_type is None:
         raise ValueError("Unsupported image format")

-    vector = embed_mixed([image] + texts)[0]
+    vector = embed_mixed(
+        [image] + texts, chunk_size=chunk_size or DEFAULT_CHUNK_TOKENS
+    )[0]

     return Chunk(
         id=str(uuid.uuid4()),
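With these changes a chunk_size attached to a Page flows from extract_content through embed_page into embed_text or embed_mixed. A minimal sketch of that path, assuming the signatures shown above (the page contents are made up, and a real run would need VoyageAI credentials):

    # Illustrative only: how a per-page chunk_size override reaches the chunker.
    from memory.common import embedding, extract

    page: extract.Page = {
        "contents": ["Some page text. " * 100],
        "metadata": {},
        "chunk_size": 32000,  # optional override; absent means DEFAULT_CHUNK_TOKENS
    }

    # embed_page reads page.get("chunk_size", DEFAULT_CHUNK_TOKENS) and forwards it,
    # so this text is chunked against a 32,000-token budget instead of 512.
    vectors = embedding.embed_page(page)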
@@ -3,7 +3,7 @@ import logging
 import pathlib
 import tempfile
 from contextlib import contextmanager
-from typing import Any, Generator, Sequence, TypedDict, cast
+from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast

 import pymupdf  # PyMuPDF
 import pypandoc
@@ -17,6 +17,8 @@ MulitmodalChunk = Image.Image | str
 class Page(TypedDict):
     contents: Sequence[MulitmodalChunk]
     metadata: dict[str, Any]
+    # This is used to override the default chunk size for the page
+    chunk_size: NotRequired[int]


 @contextmanager
@@ -108,22 +110,28 @@ def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
     return [{"contents": [cast(str, content)], "metadata": {}}]


-def extract_content(mime_type: str, content: bytes | str | pathlib.Path) -> list[Page]:
+def extract_content(
+    mime_type: str,
+    content: bytes | str | pathlib.Path,
+    chunk_size: int | None = None,
+) -> list[Page]:
+    pages = []
     logger.info(f"Extracting content from {mime_type}")
     if mime_type == "application/pdf":
-        return doc_to_images(content)
-    if mime_type in [
+        pages = doc_to_images(content)
+    elif mime_type in [
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
         logger.info(f"Extracting content from {content}")
         pages = extract_docx(content)
         logger.info(f"Extracted {len(pages)} pages from {content}")
-        return pages
-    if mime_type.startswith("text/"):
-        return extract_text(content)
-    if mime_type.startswith("image/"):
-        return extract_image(content)
+    elif mime_type.startswith("text/"):
+        pages = extract_text(content)
+    elif mime_type.startswith("image/"):
+        pages = extract_image(content)

-    # Return empty list for unknown mime types
-    return []
+    if chunk_size:
+        pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
+
+    return pages
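When a chunk_size is given, extract_content simply copies it onto every extracted Page; when it is omitted the key is absent and consumers fall back to their own default. A small sketch under the same assumptions (the sample text is made up):

    # Illustrative only: chunk_size is stamped onto every extracted page.
    from memory.common import extract

    pages = extract.extract_content("text/plain", "one two three", chunk_size=32000)
    assert all(page.get("chunk_size") == 32000 for page in pages)

    # Without the argument the key is simply not present on the pages.
    pages = extract.extract_content("text/plain", "one two three")
    assert all("chunk_size" not in page for page in pages)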
@@ -13,7 +13,7 @@ class Section:
     """Represents a chapter or section in an ebook."""

     title: str
-    content: str
+    pages: list[str]
     number: int | None = None
     start_page: int | None = None
     end_page: int | None = None
@@ -74,13 +74,13 @@ def extract_epub_metadata(doc) -> dict[str, Any]:
     return {key: value for key, value in doc.metadata.items() if value}


-def get_pages(doc, start_page: int, end_page: int) -> str:
+def get_pages(doc, start_page: int, end_page: int) -> list[str]:
     pages = [
         doc[page_num].get_text()
         for page_num in range(start_page, end_page + 1)
         if 0 <= page_num < doc.page_count
     ]
-    return "\n".join(pages)
+    return pages


 def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
@@ -96,7 +96,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
     if not next_item:
         return Section(
             title=name,
-            content=get_pages(doc, page, doc.page_count),
+            pages=get_pages(doc, page, doc.page_count),
             number=section_num,
             start_page=page,
             end_page=doc.page_count,
@@ -110,7 +110,7 @@ def extract_section_pages(doc, toc: Peekable, section_num: int = 1) -> Section | None:
     last_page = next_item[2] - 1 if next_item else doc.page_count
     return Section(
         title=name,
-        content=get_pages(doc, page, last_page),
+        pages=get_pages(doc, page, last_page),
         number=section_num,
         start_page=page,
         end_page=last_page,
@@ -125,7 +125,7 @@ def extract_sections(doc) -> list[Section]:
     return [
         Section(
             title="Content",
-            content=doc.get_text(),
+            pages=get_pages(doc, 0, doc.page_count),
             number=1,
             start_page=0,
             end_page=doc.page_count,
@@ -169,7 +169,9 @@ def parse_ebook(file_path: str | Path) -> Ebook:
     sections = extract_sections(doc)
     full_content = ""
     if sections:
-        full_content = "".join(section.content for section in sections)
+        full_content = "\n\n".join(
+            page for section in sections for page in section.pages
+        )

     return Ebook(
         title=title,
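Section now carries the raw per-page strings instead of one pre-joined content blob, and parse_ebook joins them with blank lines, so page boundaries survive into full_content. A small usage sketch assuming the attributes shown above (the file path is made up):

    # Illustrative only: pages are kept separate, full_content joins them with "\n\n".
    from memory.common.parsers.ebook import parse_ebook

    ebook = parse_ebook("/path/to/book.epub")   # hypothetical path
    chapter = ebook.sections[0]
    print(len(chapter.pages), "pages in", chapter.title)
    print(ebook.full_content[:200])             # pages joined with blank lines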
@@ -3,7 +3,7 @@ import logging
 from pathlib import Path
 from typing import Iterable, cast

-from memory.common import embedding, qdrant, settings
+from memory.common import chunker, embedding, qdrant
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection
 from memory.common.parsers.ebook import Ebook, parse_ebook, Section
@@ -44,7 +44,8 @@ def section_processor(
     level: int = 1,
     parent_key: tuple[int, int | None] | None = None,
 ):
-    if len(section.content.strip()) >= MIN_SECTION_LENGTH:
+    content = "\n\n".join(section.pages).strip()
+    if len(content) >= MIN_SECTION_LENGTH:
         sha256 = hashlib.sha256(
             f"{book.id}:{section.title}:{section.start_page}".encode()
         ).digest()
@@ -57,10 +58,11 @@ def section_processor(
             start_page=section.start_page,
             end_page=section.end_page,
             parent_section_id=None,  # Will be set after flush
-            content=section.content,
+            content=content,
             sha256=sha256,
             modality="book",
             tags=book.tags,
+            pages=section.pages,
         )

         all_sections.append(book_section)
@@ -127,13 +129,28 @@ def embed_sections(all_sections: list[BookSection]) -> int:
     """Embed all sections and return count of successfully embedded sections."""
     embedded_count = 0

+    def embed_text(text: str, metadata: dict) -> list[embedding.Chunk]:
+        _, chunks = embedding.embed(
+            "text/plain",
+            text,
+            metadata=metadata,
+            chunk_size=chunker.EMBEDDING_MAX_TOKENS,
+        )
+        return chunks
+
     for section in all_sections:
         try:
-            _, chunks = embedding.embed(
-                "text/plain",
-                cast(str, section.content),
-                metadata=section.as_payload(),
+            section_chunks = embed_text(
+                cast(str, section.content), section.as_payload()
             )
+            page_chunks = [
+                chunk
+                for i, page in enumerate(section.pages)
+                for chunk in embed_text(
+                    page, section.as_payload() | {"page_number": i + section.start_page}
+                )
+            ]
+            chunks = section_chunks + page_chunks

             if chunks:
                 section.chunks = chunks
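Passing chunk_size=chunker.EMBEDDING_MAX_TOKENS (32,000 tokens) effectively disables splitting for book material: under the 4-characters-per-token estimate a page would need to exceed roughly 128,000 characters before it were chunked, so each section normally yields one whole-section chunk plus one chunk per page tagged with its page_number. A worked example of the budget, with a made-up page size, purely for illustration:

    # Illustrative arithmetic only, mirroring the constants in memory.common.chunker.
    CHARS_PER_TOKEN = 4
    EMBEDDING_MAX_TOKENS = 32000
    DEFAULT_CHUNK_TOKENS = 512

    page_chars = 4_440                                # a long page or two of prose
    approx_tokens = page_chars // CHARS_PER_TOKEN     # 1110
    print(approx_tokens > DEFAULT_CHUNK_TOKENS)       # True: the search default would split it
    print(approx_tokens <= EMBEDDING_MAX_TOKENS)      # True: the book path keeps it whole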
@@ -222,4 +222,7 @@ def qdrant():
 @pytest.fixture(autouse=True)
 def mock_voyage_client():
     with patch.object(voyageai, "Client", autospec=True) as mock_client:
-        yield mock_client()
+        client = mock_client()
+        client.embed.return_value.embeddings = [[0.1] * 1024]
+        client.multimodal_embed.return_value.embeddings = [[0.1] * 1024]
+        yield client
@@ -83,20 +83,20 @@ def test_extract_epub_metadata(metadata_input, expected):
 @pytest.mark.parametrize(
     "start_page,end_page,expected_content",
     [
-        (0, 2, "Content of page 0\nContent of page 1\nContent of page 2"),
-        (3, 4, "Content of page 3\nContent of page 4"),
-        (4, 4, "Content of page 4"),
+        (0, 2, ["Content of page 0", "Content of page 1", "Content of page 2"]),
+        (3, 4, ["Content of page 3", "Content of page 4"]),
+        (4, 4, ["Content of page 4"]),
         (
             0,
             10,
-            "Content of page 0\nContent of page 1\nContent of page 2\nContent of page 3\nContent of page 4",
+            [f"Content of page {i}" for i in range(5)],
         ),  # Out of range
-        (5, 10, ""),  # Completely out of range
-        (3, 2, ""),  # Invalid range (start > end)
+        (5, 10, []),  # Completely out of range
+        (3, 2, []),  # Invalid range (start > end)
         (
             -1,
             2,
-            "Content of page 0\nContent of page 1\nContent of page 2",
+            [f"Content of page {i}" for i in range(3)],
         ),  # Negative start
     ],
 )
@@ -121,21 +121,21 @@ def test_extract_section_pages(mock_doc, mock_toc_items):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         )
@@ -148,21 +148,21 @@ def test_extract_sections(mock_doc):
             number=1,
             start_page=0,
             end_page=2,
-            content="Content of page 0\nContent of page 1\nContent of page 2",
+            pages=["Content of page 0", "Content of page 1", "Content of page 2"],
             children=[
                 Section(
                     title="Section 1.1",
                     number=1,
                     start_page=1,
                     end_page=1,
-                    content="Content of page 1",
+                    pages=["Content of page 1"],
                 ),
                 Section(
                     title="Section 1.2",
                     number=2,
                     start_page=2,
                     end_page=2,
-                    content="Content of page 2",
+                    pages=["Content of page 2"],
                 ),
             ],
         ),
@@ -171,14 +171,14 @@ def test_extract_sections(mock_doc):
             number=2,
             start_page=3,
             end_page=5,
-            content="Content of page 3\nContent of page 4",
+            pages=["Content of page 3", "Content of page 4"],
             children=[
                 Section(
                     title="Section 2.1",
                     number=1,
                     start_page=4,
                     end_page=5,
-                    content="Content of page 4",
+                    pages=["Content of page 4"],
                 ),
             ],
         ),
@@ -195,7 +195,7 @@ def test_extract_sections_no_toc(mock_doc):
             number=1,
             start_page=0,
             end_page=5,
-            content="Full document content",
+            pages=[f"Content of page {i}" for i in range(5)],
             children=[],
         ),
     ]
@@ -376,9 +376,9 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):

     # Create sections with specific content
     section1 = MagicMock()
-    section1.content = "Content of section 1"
+    section1.pages = ["Content of section 1"]
     section2 = MagicMock()
-    section2.content = "Content of section 2"
+    section2.pages = ["Content of section 2"]

     # Mock extract_sections to return our sections
     with patch("memory.common.parsers.ebook.extract_sections") as mock_extract:
@@ -391,4 +391,4 @@ def test_parse_ebook_full_content_generation(mock_open, mock_doc, tmp_path):
     ebook = parse_ebook(test_file)

     # Check the full content is concatenated correctly
-    assert ebook.full_content == "Content of section 1Content of section 2"
+    assert ebook.full_content == "Content of section 1\n\nContent of section 2"
@@ -17,22 +17,21 @@ def mock_ebook():
         sections=[
             Section(
                 title="Chapter 1",
-                content="This is the content of chapter 1. "
-                * 20,  # Make it long enough
+                pages=["This is the content of chapter 1. " * 20],
                 number=1,
                 start_page=1,
                 end_page=10,
                 children=[
                     Section(
                         title="Section 1.1",
-                        content="This is section 1.1 content. " * 15,
+                        pages=["This is section 1.1 content. " * 15],
                         number=1,
                         start_page=1,
                         end_page=5,
                     ),
                     Section(
                         title="Section 1.2",
-                        content="This is section 1.2 content. " * 15,
+                        pages=["This is section 1.2 content. " * 15],
                         number=2,
                         start_page=6,
                         end_page=10,
@@ -41,7 +40,7 @@ def mock_ebook():
             ),
             Section(
                 title="Chapter 2",
-                content="This is the content of chapter 2. " * 20,
+                pages=["This is the content of chapter 2. " * 20],
                 number=2,
                 start_page=11,
                 end_page=20,
@@ -52,7 +51,7 @@ def mock_ebook():
     )


-@pytest.fixture(autouse=True)
+@pytest.fixture
 def mock_embedding():
     """Mock the embedding function to return dummy vectors."""
     with patch("memory.workers.tasks.ebook.embedding.embed") as mock:
@@ -316,3 +315,55 @@ def test_sync_book_file_not_found():
     """Test handling of missing files."""
     with pytest.raises(FileNotFoundError):
         ebook.sync_book("/nonexistent/file.epub")
+
+
+def test_embed_sections_uses_correct_chunk_size(db_session, mock_voyage_client):
+    """Test that book sections with large pages are passed whole to the embedding function."""
+    # Create a test book first
+    book = Book(
+        title="Test Book",
+        author="Test Author",
+        file_path="/test/path",
+    )
+    db_session.add(book)
+    db_session.flush()
+
+    # Create large content that exceeds 1000 tokens (4000+ characters)
+    large_section_content = "This is a very long section content. " * 120  # ~4440 chars
+    large_page_1 = "This is page 1 with lots of content. " * 120  # ~4440 chars
+    large_page_2 = "This is page 2 with lots of content. " * 120  # ~4440 chars
+
+    # Create test sections with large content and pages
+    sections = [
+        BookSection(
+            book_id=book.id,
+            section_title="Test Section",
+            section_number=1,
+            section_level=1,
+            start_page=1,
+            end_page=10,
+            content=large_section_content,
+            sha256=b"test_hash",
+            modality="book",
+            tags=["book"],
+            pages=[large_page_1, large_page_2],
+        )
+    ]
+
+    db_session.add_all(sections)
+    db_session.flush()
+
+    ebook.embed_sections(sections)
+
+    # Verify that the voyage client was called with the full large content
+    # Should be called 3 times: once for section content, twice for pages
+    assert mock_voyage_client.embed.call_count == 3
+
+    # Check that the full content was passed to the embedding function
+    calls = mock_voyage_client.embed.call_args_list
+    texts = [c[0][0] for c in calls]
+    assert texts == [
+        [large_section_content.strip()],
+        [large_page_1.strip()],
+        [large_page_2.strip()],
+    ]
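As a sanity check on the sizes the new test uses: each fixture string is 37 characters repeated 120 times, which under the 4-characters-per-token estimate is well above the 512-token search default but far below the 32,000-token book limit, so every section and page reaches the mocked client as exactly one embed call (three calls in total). A quick back-of-the-envelope check, illustrative only:

    # Illustrative arithmetic for the fixture sizes above.
    chars = len("This is a very long section content. ") * 120   # 37 * 120 = 4440
    approx_tokens = chars // 4                                    # 1110
    assert 512 < approx_tokens < 32000   # would split under the old default, stays whole here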