add chunker

Daniel O'Connell 2025-04-28 14:17:02 +02:00
parent 3dca666d08
commit 2d2f37536a
6 changed files with 417 additions and 21 deletions

View File

@@ -1,5 +1,13 @@
import pathlib
from typing import Literal, TypedDict
from typing import Literal, TypedDict, Iterable
import voyageai
import re
# Chunking configuration
MAX_TOKENS = 32000 # VoyageAI max context window
OVERLAP_TOKENS = 200 # Default overlap between chunks
CHARS_PER_TOKEN = 4
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
@@ -12,13 +20,13 @@ class Collection(TypedDict):
DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1536, "distance": "Cosine"},
"chat": {"dimension": 1536, "distance": "Cosine"},
"git": {"dimension": 1536, "distance": "Cosine"},
"mail": {"dimension": 1024, "distance": "Cosine"},
"chat": {"dimension": 1024, "distance": "Cosine"},
"git": {"dimension": 1024, "distance": "Cosine"},
"photo": {"dimension": 512, "distance": "Cosine"},
"book": {"dimension": 1536, "distance": "Cosine"},
"blog": {"dimension": 1536, "distance": "Cosine"},
"doc": {"dimension": 1536, "distance": "Cosine"},
"book": {"dimension": 1024, "distance": "Cosine"},
"blog": {"dimension": 1024, "distance": "Cosine"},
"doc": {"dimension": 1024, "distance": "Cosine"},
}
TYPES = {
@@ -40,21 +48,138 @@ def get_type(mime_type: str) -> str:
return "unknown"
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
# Regex for sentence splitting
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
def approx_token_count(s: str) -> int:
return len(s) // CHARS_PER_TOKEN
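For orientation, the estimate is purely character-based and rounds down (illustrative values only):
approx_token_count("This is sentence 01.")  # 20 characters -> 5 estimated tokens
approx_token_count("hi")                    # 2 characters -> 0 estimated tokens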
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
words = text.split()
if not words:
return
current = ""
for word in words:
new_chunk = f"{current} {word}".strip()
if current and approx_token_count(new_chunk) > max_tokens:
yield current
current = word
else:
current = new_chunk
if current: # Only yield non-empty final chunk
yield current
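A quick sketch of the greedy packing above (values mirror test_yield_word_chunk_small_token_limit further down): words keep being appended until adding the next word would push the estimate over the limit.
list(yield_word_chunks("one two three four five", max_tokens=1))
# -> ['one two', 'three', 'four', 'five']  ("one two" is 7 chars, still 1 estimated token)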
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
"""
Embed a text using OpenAI's API.
Yield text spans in priority order: paragraphs, sentences, words.
Each span stays within max_tokens wherever possible; a single word longer than the limit is yielded as-is.
Args:
text: The text to split
max_tokens: Maximum tokens per chunk
Yields:
Spans of text that fit within token limits
"""
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
# Return early for empty text
if not text.strip():
return
for paragraph in text.split("\n\n"):
if not paragraph.strip():
continue
if approx_token_count(paragraph) <= max_tokens:
yield paragraph
continue
for sentence in _SENT_SPLIT_RE.split(paragraph):
if not sentence.strip():
continue
if approx_token_count(sentence) <= max_tokens:
yield sentence
continue
for chunk in yield_word_chunks(sentence, max_tokens):
yield chunk
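A small illustration of the fallback order (mirroring test_yield_spans_paragraphs below): paragraphs that already fit the budget are yielded whole, without sentence- or word-level splitting.
list(yield_spans("Paragraph one.\n\nParagraph two.", max_tokens=10))
# -> ['Paragraph one.', 'Paragraph two.']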
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
"""
Embed a file using OpenAI's API.
Split text into chunks respecting semantic boundaries while staying within token limits.
Args:
text: The text to chunk
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
overlap: Number of tokens to overlap between chunks (default: 200)
Yields:
Chunks of text that fit within token limits
"""
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
text = text.strip()
if not text:
return
if approx_token_count(text) <= max_tokens:
yield text
return
overlap_chars = overlap * CHARS_PER_TOKEN
current = ""
for span in yield_spans(text, max_tokens):
current = f"{current} {span}".strip()
if approx_token_count(current) < max_tokens:
continue
if overlap <= 0:
yield current
current = ""
continue
overlap_text = current[-overlap_chars:]
clean_break = max(
overlap_text.rfind(". "),
overlap_text.rfind("! "),
overlap_text.rfind("? ")
)
print(f"clean_break: {clean_break}")
print(f"overlap_text: {overlap_text}")
print(f"current: {current}")
if clean_break < 0:
yield current
current = ""
continue
break_offset = -overlap_chars + clean_break + 1
chunk = current[break_offset:].strip()
print(f"chunk: {chunk}")
print(f"current: {current}")
yield current
current = chunk
if current:
yield current.strip()
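To illustrate the overlap handling above (values mirror test_chunk_text_with_overlap below): once a chunk fills up, its tail, starting at the last sentence break inside the overlap window, is carried into the next chunk.
list(chunk_text("Part A. Part B. Part C. Part D. Part E.", max_tokens=4, overlap=3))
# -> ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
# "Part C." falls inside the ~12-char overlap window of the first chunk, so it is repeated.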
def embed(mime_type: str, content: bytes | str | pathlib.Path, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> tuple[str, list[float]]:
collection = get_type(mime_type)
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
vo = voyageai.Client()
chunks = list(chunk_text(text, MAX_TOKENS))
# Client.embed returns an EmbeddingsObject; its .embeddings field holds one vector per chunk
return vo.embed(chunks, model=model).embeddings
return collection, [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
return embed_text(file_path.read_text(), model, n_dimensions)
def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]:
if isinstance(content, bytes):
content = content.decode("utf-8")
return get_type(mime_type), embed_text(content, model, n_dimensions)
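A rough usage sketch of the new embedding entry points (assumptions: a VOYAGE_API_KEY is configured for voyageai.Client, and the file path below is purely hypothetical):
content = pathlib.Path("notes/example.md").read_text()  # hypothetical input
collection, vectors = embed("text/markdown", content)
# collection is the target collection name from get_type(), e.g. "doc";
# vectors holds one embedding per chunk produced by chunk_text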

View File

@@ -12,9 +12,9 @@ from typing import Generator, Callable, TypedDict, Literal
import pathlib
from sqlalchemy.orm import Session
from collections import defaultdict
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import settings, embedding
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common.qdrant import get_qdrant_client, upsert_vectors
logger = logging.getLogger(__name__)

View File

@@ -12,8 +12,8 @@ from sqlalchemy.orm import sessionmaker
from testcontainers.qdrant import QdrantContainer
from memory.common import settings
from memory.common.qdrant import initialize_collections
from tests.providers.email_provider import MockEmailProvider
from memory.workers.qdrant import initialize_collections
def get_test_db_name() -> str:

View File

@@ -0,0 +1,272 @@
import pytest
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count
@pytest.mark.parametrize(
"text, expected",
[
("", []),
("hello", ["hello"]),
("This is a simple sentence", ["This is a simple sentence"]),
("word1 word2", ["word1 word2"]),
(" ", []), # Just spaces
("\n\t ", []), # Whitespace characters
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
]
)
def test_yield_word_chunk_basic_behavior(text, expected):
"""Test basic behavior of yield_word_chunks with various inputs"""
assert list(yield_word_chunks(text)) == expected
@pytest.mark.parametrize(
"text, expected",
[
(
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
),
(
"supercalifragilisticexpialidocious",
["supercalifragilisticexpialidocious"],
)
]
)
def test_yield_word_chunk_long_text(text, expected):
"""Test chunking with long text that exceeds token limits"""
assert list(yield_word_chunks(text, max_tokens=10)) == expected
def test_yield_word_chunk_single_long_word():
"""Test behavior with a single word longer than the token limit"""
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
chunks = list(yield_word_chunks(long_word, max_tokens))
# A single word longer than the limit is yielded whole rather than split
assert len(chunks) == 1
assert chunks[0] == long_word
def test_yield_word_chunk_small_token_limit():
"""Test with a very small max_tokens value to force chunking"""
text = "one two three four five"
max_tokens = 1 # Very small to force chunking after each word
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
@pytest.mark.parametrize(
"text, max_tokens, expected_chunks",
[
# Empty text
("", 10, []),
# Text below token limit
("hello world", 10, ["hello world"]),
# Text right at token limit
(
"word1 word2", # 11 chars with space
3, # 12 chars limit
["word1 word2"]
),
# Text slightly over the limit in characters still rounds down to it (17 // 4 == 4), so no split
(
"word1 word2 word3", # 17 chars with spaces
4, # 16 chars limit
["word1 word2 word3"]
),
# Each word exactly at token limit
(
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
1, # 1 token limit (4 chars)
["aaaa", "bbbb", "cccc"]
),
]
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
"""Test different combinations of text and token limits"""
assert list(yield_word_chunks(text, max_tokens)) == expected_chunks
def test_yield_word_chunk_real_world_example():
"""Test with a realistic text example"""
text = (
"The yield_word_chunks function splits text into chunks based on word boundaries. "
"It tries to maximize chunk size while staying under the specified token limit. "
"This behavior is essential for processing large documents efficiently."
)
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
assert list(yield_word_chunks(text, max_tokens)) == [
'The yield_word_chunks function splits text',
'into chunks based on word boundaries. It',
'tries to maximize chunk size while staying',
'under the specified token limit. This',
'behavior is essential for processing large',
'documents efficiently.',
]
# Tests for yield_spans function
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text should yield nothing
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
(" ", []), # Just whitespace
]
)
def test_yield_spans_basic_behavior(text, expected):
"""Test basic behavior of yield_spans with various inputs"""
assert list(yield_spans(text)) == expected
def test_yield_spans_paragraphs():
"""Test splitting by paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
assert list(yield_spans(text)) == expected
def test_yield_spans_sentences():
"""Test splitting by sentences when paragraphs exceed token limit"""
# Create a paragraph that exceeds token limit but sentences are within limit
max_tokens = 5 # 20 chars with CHARS_PER_TOKEN = 4
sentence1 = "Short sentence one." # ~20 chars
sentence2 = "Another short sentence." # ~24 chars
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
# Sentence splitting preserves the trailing punctuation
expected = ["Short sentence one.", "Another short sentence."]
assert list(yield_spans(text, max_tokens)) == expected
def test_yield_spans_words():
"""Test splitting by words when sentences exceed token limit"""
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
long_sentence = "This sentence has several words and needs word-level chunking."
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
def test_yield_spans_complex_document():
"""Test with a document containing multiple paragraphs and sentences"""
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
text = (
"Paragraph one with a short sentence. And another sentence that should be split.\n\n"
"Paragraph two is also here. It has multiple sentences. Some are short. "
"This one is longer and might need word splitting depending on the limit.\n\n"
"Final short paragraph."
)
assert list(yield_spans(text, max_tokens)) == [
"Paragraph one with a short sentence.",
"And another sentence that should be split.",
"Paragraph two is also here.",
"It has multiple sentences.",
"Some are short.",
"This one is longer and might need word",
"splitting depending on the limit.",
"Final short paragraph."
]
def test_yield_spans_very_long_word():
"""Test with a word that exceeds the token limit"""
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
assert list(yield_spans(long_word, max_tokens)) == [long_word]
def test_yield_spans_with_punctuation():
"""Test sentence splitting with various punctuation"""
text = "First sentence! Second sentence? Third sentence."
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
def test_yield_spans_edge_cases():
"""Test edge cases like empty paragraphs, single character paragraphs"""
text = "\n\nA\n\n\n\nB\n\n"
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text
("Short text", ["Short text"]), # Text below token limit
(" ", []), # Just whitespace
]
)
def test_chunk_text_basic_behavior(text, expected):
"""Test basic behavior of chunk_text with various inputs"""
assert list(chunk_text(text)) == expected
def test_chunk_text_single_paragraph():
"""Test chunking a single paragraph that fits within token limit"""
text = "This is a simple paragraph that should fit in one chunk."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_multi_paragraph():
"""Test chunking multiple paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_long_text():
"""Test chunking with long text that exceeds token limit"""
# Create a long text that will need multiple chunks
sentences = [f"This is sentence {i:02}." for i in range(50)]
text = " ".join(sentences)
max_tokens = 10 # 10 tokens = ~40 chars
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
] + [
'This is sentence 49.'
]
def test_chunk_text_with_overlap():
"""Test chunking with overlap between chunks"""
# Create text with distinct parts to test overlap
text = "Part A. Part B. Part C. Part D. Part E."
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
def test_chunk_text_zero_overlap():
"""Test chunking with zero overlap"""
text = "Part A. Part B. Part C. Part D. Part E."
# 2 tokens = ~8 chars
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
def test_chunk_text_clean_break():
"""Test that chunking attempts to break at sentence boundaries"""
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
max_tokens = 5 # Enough for about 2 sentences
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
def test_chunk_text_very_long_sentences():
"""Test with very long sentences that exceed the token limit"""
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
max_tokens = 5 # Small limit to force splitting
assert list(chunk_text(text, max_tokens=max_tokens)) == [
'This is a very long sentence with many many',
'words that will definitely exceed the',
'token limit we set for',
'this particular test',
'case and should be split into multiple',
'chunks by the function.',
]

View File

@@ -5,12 +5,11 @@ import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.workers.qdrant import (
from memory.common.qdrant import (
DEFAULT_COLLECTIONS,
ensure_collection_exists,
initialize_collections,
upsert_vectors,
search_vectors,
delete_vectors,
)