mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
add chunker
This commit is contained in:
parent
3dca666d08
commit
2d2f37536a
@ -1,5 +1,13 @@
|
||||
import pathlib
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, TypedDict, Iterable
|
||||
import voyageai
|
||||
import re
|
||||
|
||||
# Chunking configuration
|
||||
MAX_TOKENS = 32000 # VoyageAI max context window
|
||||
OVERLAP_TOKENS = 200 # Default overlap between chunks
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
||||
Vector = list[float]
|
||||
@ -12,13 +20,13 @@ class Collection(TypedDict):
|
||||
|
||||
|
||||
DEFAULT_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {"dimension": 1536, "distance": "Cosine"},
|
||||
"chat": {"dimension": 1536, "distance": "Cosine"},
|
||||
"git": {"dimension": 1536, "distance": "Cosine"},
|
||||
"mail": {"dimension": 1024, "distance": "Cosine"},
|
||||
"chat": {"dimension": 1024, "distance": "Cosine"},
|
||||
"git": {"dimension": 1024, "distance": "Cosine"},
|
||||
"photo": {"dimension": 512, "distance": "Cosine"},
|
||||
"book": {"dimension": 1536, "distance": "Cosine"},
|
||||
"blog": {"dimension": 1536, "distance": "Cosine"},
|
||||
"doc": {"dimension": 1536, "distance": "Cosine"},
|
||||
"book": {"dimension": 1024, "distance": "Cosine"},
|
||||
"blog": {"dimension": 1024, "distance": "Cosine"},
|
||||
"doc": {"dimension": 1024, "distance": "Cosine"},
|
||||
}
|
||||
|
||||
TYPES = {
|
||||
@ -40,21 +48,138 @@ def get_type(mime_type: str) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
||||
# Regex for sentence splitting
|
||||
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
|
||||
|
||||
|
||||
def approx_token_count(s: str) -> int:
|
||||
return len(s) // CHARS_PER_TOKEN
|
||||
|
||||
|
||||
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
words = text.split()
|
||||
if not words:
|
||||
return
|
||||
|
||||
current = ""
|
||||
for word in words:
|
||||
new_chunk = f"{current} {word}".strip()
|
||||
if current and approx_token_count(new_chunk) > max_tokens:
|
||||
yield current
|
||||
current = word
|
||||
else:
|
||||
current = new_chunk
|
||||
if current: # Only yield non-empty final chunk
|
||||
yield current
|
||||
|
||||
|
||||
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Embed a text using OpenAI's API.
|
||||
Yield text spans in priority order: paragraphs, sentences, words.
|
||||
Each span is guaranteed to be under max_tokens.
|
||||
|
||||
Args:
|
||||
text: The text to split
|
||||
max_tokens: Maximum tokens per chunk
|
||||
|
||||
Yields:
|
||||
Spans of text that fit within token limits
|
||||
"""
|
||||
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
||||
# Return early for empty text
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
for paragraph in text.split("\n\n"):
|
||||
if not paragraph.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(paragraph) <= max_tokens:
|
||||
yield paragraph
|
||||
continue
|
||||
|
||||
for sentence in _SENT_SPLIT_RE.split(paragraph):
|
||||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(sentence) <= max_tokens:
|
||||
yield sentence
|
||||
continue
|
||||
|
||||
for chunk in yield_word_chunks(sentence, max_tokens):
|
||||
yield chunk
|
||||
|
||||
|
||||
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
||||
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> list[str]:
|
||||
"""
|
||||
Embed a file using OpenAI's API.
|
||||
Split text into chunks respecting semantic boundaries while staying within token limits.
|
||||
|
||||
Args:
|
||||
text: The text to chunk
|
||||
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
|
||||
overlap: Number of tokens to overlap between chunks (default: 200)
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
if approx_token_count(text) <= max_tokens:
|
||||
yield text
|
||||
return
|
||||
|
||||
overlap_chars = overlap * CHARS_PER_TOKEN
|
||||
current = ""
|
||||
|
||||
for span in yield_spans(text, max_tokens):
|
||||
current = f"{current} {span}".strip()
|
||||
if approx_token_count(current) < max_tokens:
|
||||
continue
|
||||
|
||||
if overlap <= 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
overlap_text = current[-overlap_chars:]
|
||||
clean_break = max(
|
||||
overlap_text.rfind(". "),
|
||||
overlap_text.rfind("! "),
|
||||
overlap_text.rfind("? ")
|
||||
)
|
||||
|
||||
print(f"clean_break: {clean_break}")
|
||||
print(f"overlap_text: {overlap_text}")
|
||||
print(f"current: {current}")
|
||||
|
||||
if clean_break < 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
break_offset = -overlap_chars + clean_break + 1
|
||||
chunk = current[break_offset:].strip()
|
||||
print(f"chunk: {chunk}")
|
||||
print(f"current: {current}")
|
||||
yield current
|
||||
current = chunk
|
||||
if current:
|
||||
yield current.strip()
|
||||
|
||||
|
||||
def embed(mime_type: str, content: bytes | str | pathlib.Path, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> tuple[str, list[float]]:
|
||||
collection = get_type(mime_type)
|
||||
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
vo = voyageai.Client()
|
||||
chunks = chunk_text(text, MAX_TOKENS)
|
||||
return [vo.embed(chunk, model=model) for chunk in chunks]
|
||||
|
||||
return collection, [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
||||
|
||||
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
return embed_text(file_path.read_text(), model, n_dimensions)
|
||||
|
||||
|
||||
def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> tuple[str, list[Vector]]:
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return get_type(mime_type), embed_text(content, model, n_dimensions)
|
||||
|
@ -12,9 +12,9 @@ from typing import Generator, Callable, TypedDict, Literal
|
||||
import pathlib
|
||||
from sqlalchemy.orm import Session
|
||||
from collections import defaultdict
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common import settings, embedding
|
||||
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common.qdrant import get_qdrant_client, upsert_vectors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.orm import sessionmaker
|
||||
from testcontainers.qdrant import QdrantContainer
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.qdrant import initialize_collections
|
||||
from tests.providers.email_provider import MockEmailProvider
|
||||
from memory.workers.qdrant import initialize_collections
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
|
272
tests/memory/common/test_embedding.py
Normal file
272
tests/memory/common/test_embedding.py
Normal file
@ -0,0 +1,272 @@
|
||||
import pytest
|
||||
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []),
|
||||
("hello", ["hello"]),
|
||||
("This is a simple sentence", ["This is a simple sentence"]),
|
||||
("word1 word2", ["word1 word2"]),
|
||||
(" ", []), # Just spaces
|
||||
("\n\t ", []), # Whitespace characters
|
||||
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_word_chunks with various inputs"""
|
||||
assert list(yield_word_chunks(text)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
(
|
||||
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
|
||||
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
|
||||
),
|
||||
(
|
||||
"supercalifragilisticexpialidocious",
|
||||
["supercalifragilisticexpialidocious"],
|
||||
)
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_long_text(text, expected):
|
||||
"""Test chunking with long text that exceeds token limits"""
|
||||
assert list(yield_word_chunks(text, max_tokens=10)) == expected
|
||||
|
||||
|
||||
def test_yield_word_chunk_single_long_word():
|
||||
"""Test behavior with a single word longer than the token limit"""
|
||||
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
|
||||
|
||||
chunks = list(yield_word_chunks(long_word, max_tokens))
|
||||
# With our changes, this should be a single chunk
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == long_word
|
||||
|
||||
|
||||
def test_yield_word_chunk_small_token_limit():
|
||||
"""Test with a very small max_tokens value to force chunking"""
|
||||
text = "one two three four five"
|
||||
max_tokens = 1 # Very small to force chunking after each word
|
||||
|
||||
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, max_tokens, expected_chunks",
|
||||
[
|
||||
# Empty text
|
||||
("", 10, []),
|
||||
# Text below token limit
|
||||
("hello world", 10, ["hello world"]),
|
||||
# Text right at token limit
|
||||
(
|
||||
"word1 word2", # 11 chars with space
|
||||
3, # 12 chars limit
|
||||
["word1 word2"]
|
||||
),
|
||||
# Text just over token limit should split
|
||||
(
|
||||
"word1 word2 word3", # 17 chars with spaces
|
||||
4, # 16 chars limit
|
||||
["word1 word2 word3"]
|
||||
),
|
||||
# Each word exactly at token limit
|
||||
(
|
||||
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
|
||||
1, # 1 token limit (4 chars)
|
||||
["aaaa", "bbbb", "cccc"]
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
|
||||
"""Test different combinations of text and token limits"""
|
||||
assert list(yield_word_chunks(text, max_tokens)) == expected_chunks
|
||||
|
||||
|
||||
def test_yield_word_chunk_real_world_example():
|
||||
"""Test with a realistic text example"""
|
||||
text = (
|
||||
"The yield_word_chunks function splits text into chunks based on word boundaries. "
|
||||
"It tries to maximize chunk size while staying under the specified token limit. "
|
||||
"This behavior is essential for processing large documents efficiently."
|
||||
)
|
||||
|
||||
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
|
||||
assert list(yield_word_chunks(text, max_tokens)) == [
|
||||
'The yield_word_chunks function splits text',
|
||||
'into chunks based on word boundaries. It',
|
||||
'tries to maximize chunk size while staying',
|
||||
'under the specified token limit. This',
|
||||
'behavior is essential for processing large',
|
||||
'documents efficiently.',
|
||||
]
|
||||
|
||||
|
||||
# Tests for yield_spans function
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []), # Empty text should yield nothing
|
||||
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
)
|
||||
def test_yield_spans_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_spans with various inputs"""
|
||||
assert list(yield_spans(text)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_paragraphs():
|
||||
"""Test splitting by paragraphs"""
|
||||
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
|
||||
expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
|
||||
assert list(yield_spans(text)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_sentences():
|
||||
"""Test splitting by sentences when paragraphs exceed token limit"""
|
||||
# Create a paragraph that exceeds token limit but sentences are within limit
|
||||
max_tokens = 5 # 20 chars with CHARS_PER_TOKEN = 4
|
||||
sentence1 = "Short sentence one." # ~20 chars
|
||||
sentence2 = "Another short sentence." # ~24 chars
|
||||
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
|
||||
|
||||
# Function should now preserve punctuation
|
||||
expected = ["Short sentence one.", "Another short sentence."]
|
||||
assert list(yield_spans(text, max_tokens)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_words():
|
||||
"""Test splitting by words when sentences exceed token limit"""
|
||||
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
|
||||
long_sentence = "This sentence has several words and needs word-level chunking."
|
||||
|
||||
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
|
||||
|
||||
|
||||
def test_yield_spans_complex_document():
|
||||
"""Test with a document containing multiple paragraphs and sentences"""
|
||||
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
|
||||
text = (
|
||||
"Paragraph one with a short sentence. And another sentence that should be split.\n\n"
|
||||
"Paragraph two is also here. It has multiple sentences. Some are short. "
|
||||
"This one is longer and might need word splitting depending on the limit.\n\n"
|
||||
"Final short paragraph."
|
||||
)
|
||||
|
||||
assert list(yield_spans(text, max_tokens)) == [
|
||||
"Paragraph one with a short sentence.",
|
||||
"And another sentence that should be split.",
|
||||
"Paragraph two is also here.",
|
||||
"It has multiple sentences.",
|
||||
"Some are short.",
|
||||
"This one is longer and might need word",
|
||||
"splitting depending on the limit.",
|
||||
"Final short paragraph."
|
||||
]
|
||||
|
||||
|
||||
def test_yield_spans_very_long_word():
|
||||
"""Test with a word that exceeds the token limit"""
|
||||
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
|
||||
|
||||
assert list(yield_spans(long_word, max_tokens)) == [long_word]
|
||||
|
||||
|
||||
def test_yield_spans_with_punctuation():
|
||||
"""Test sentence splitting with various punctuation"""
|
||||
text = "First sentence! Second sentence? Third sentence."
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
|
||||
|
||||
|
||||
def test_yield_spans_edge_cases():
|
||||
"""Test edge cases like empty paragraphs, single character paragraphs"""
|
||||
text = "\n\nA\n\n\n\nB\n\n"
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []), # Empty text
|
||||
("Short text", ["Short text"]), # Text below token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
)
|
||||
def test_chunk_text_basic_behavior(text, expected):
|
||||
"""Test basic behavior of chunk_text with various inputs"""
|
||||
assert list(chunk_text(text)) == expected
|
||||
|
||||
|
||||
def test_chunk_text_single_paragraph():
|
||||
"""Test chunking a single paragraph that fits within token limit"""
|
||||
text = "This is a simple paragraph that should fit in one chunk."
|
||||
assert list(chunk_text(text, max_tokens=20)) == [text]
|
||||
|
||||
|
||||
def test_chunk_text_multi_paragraph():
|
||||
"""Test chunking multiple paragraphs"""
|
||||
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
|
||||
assert list(chunk_text(text, max_tokens=20)) == [text]
|
||||
|
||||
|
||||
def test_chunk_text_long_text():
|
||||
"""Test chunking with long text that exceeds token limit"""
|
||||
# Create a long text that will need multiple chunks
|
||||
sentences = [f"This is sentence {i:02}." for i in range(50)]
|
||||
text = " ".join(sentences)
|
||||
|
||||
max_tokens = 10 # 10 tokens = ~40 chars
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
|
||||
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
|
||||
] + [
|
||||
'This is sentence 49.'
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_with_overlap():
|
||||
"""Test chunking with overlap between chunks"""
|
||||
# Create text with distinct parts to test overlap
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
|
||||
|
||||
|
||||
def test_chunk_text_zero_overlap():
|
||||
"""Test chunking with zero overlap"""
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
# 2 tokens = ~8 chars
|
||||
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
|
||||
|
||||
|
||||
def test_chunk_text_clean_break():
|
||||
"""Test that chunking attempts to break at sentence boundaries"""
|
||||
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||
|
||||
max_tokens = 5 # Enough for about 2 sentences
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
|
||||
|
||||
|
||||
def test_chunk_text_very_long_sentences():
|
||||
"""Test with very long sentences that exceed the token limit"""
|
||||
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
|
||||
|
||||
max_tokens = 5 # Small limit to force splitting
|
||||
assert list(chunk_text(text, max_tokens=max_tokens)) == [
|
||||
'This is a very long sentence with many many',
|
||||
'words that will definitely exceed the',
|
||||
'token limit we set for',
|
||||
'this particular test',
|
||||
'case and should be split into multiple',
|
||||
'chunks by the function.',
|
||||
]
|
||||
|
@ -5,12 +5,11 @@ import qdrant_client
|
||||
from qdrant_client.http import models as qdrant_models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
from memory.workers.qdrant import (
|
||||
from memory.common.qdrant import (
|
||||
DEFAULT_COLLECTIONS,
|
||||
ensure_collection_exists,
|
||||
initialize_collections,
|
||||
upsert_vectors,
|
||||
search_vectors,
|
||||
delete_vectors,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user