second pass in search

Daniel O'Connell 2025-06-28 20:59:15 +02:00
parent 06eec621c1
commit 8eb6374cac
11 changed files with 445 additions and 142 deletions

View File

@@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
     if (isLoading) {
         return <Loading message="Searching..." />
     }
+    console.log("results",results)
     return (
         <div className="search-results">
             {results.length > 0 && (

View File

@@ -0,0 +1,65 @@
import asyncio

from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens

SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    data = chunk.data
    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)

    response = await asyncio.to_thread(
        llms.call,
        prompt,
        settings.RANKER_MODEL,
        images=images,
        system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
    )

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0
    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )
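
To make the scoring flow concrete, here is a minimal sketch of the <score> extraction above, with a hard-coded reply standing in for the real llms.call response (the reply text is an assumption for illustration):

    from bs4 import BeautifulSoup

    response = "The chunk answers the query directly. <score>0.85</score>"
    soup = BeautifulSoup(response, "html.parser")
    score_tag = soup.find("score")
    relevance = float(score_tag.text.strip()) if score_tag else 0.0
    print(relevance)  # 0.85

If the model omits the tag or puts something non-numeric inside it, score_chunk falls back to a relevance of 0.0 instead of raising.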

View File

@@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 from memory.common.collections import ALL_COLLECTIONS
 from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer

 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks

@@ -40,7 +41,14 @@ async def search_chunks(
     with make_session() as db:
         chunks = (
             db.query(Chunk)
-            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .options(
+                load_only(
+                    Chunk.id,  # type: ignore
+                    Chunk.source_id,  # type: ignore
+                    Chunk.content,  # type: ignore
+                    Chunk.file_paths,  # type: ignore
+                )
+            )
             .filter(Chunk.id.in_(all_ids))
             .all()
         )

@@ -91,4 +99,6 @@ async def search(
         filters,
         timeout,
     )
+    if settings.ENABLE_SEARCH_SCORING:
+        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
     return await search_sources(chunks, previews)

View File

@@ -2,7 +2,7 @@ import logging
 from typing import Iterable, Any
 import re

-from memory.common import settings
+from memory.common import settings, tokens

 logger = logging.getLogger(__name__)

@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
 EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
 DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
 OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4

 Vector = list[float]

@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 _SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


-def approx_token_count(s: str) -> int:
-    return len(s) // CHARS_PER_TOKEN
-
-
 def yield_word_chunks(
     text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
 ) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
-        if current and approx_token_count(new_chunk) > max_tokens:
+        if current and tokens.approx_token_count(new_chunk) > max_tokens:
             yield current
             current = word
         else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
         if not paragraph.strip():
             continue

-        if approx_token_count(paragraph) <= max_tokens:
+        if tokens.approx_token_count(paragraph) <= max_tokens:
             yield paragraph
             continue

@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
             if not sentence.strip():
                 continue

-            if approx_token_count(sentence) <= max_tokens:
+            if tokens.approx_token_count(sentence) <= max_tokens:
                 yield sentence
                 continue

@@ -99,16 +94,16 @@ def chunk_text(
     if not text:
         return

-    if approx_token_count(text) <= max_tokens:
+    if tokens.approx_token_count(text) <= max_tokens:
         yield text
         return

-    overlap_chars = overlap * CHARS_PER_TOKEN
+    overlap_chars = overlap * tokens.CHARS_PER_TOKEN

     current = ""
     for span in yield_spans(text, max_tokens):
         current = f"{current} {span}".strip()
-        if approx_token_count(current) < max_tokens:
+        if tokens.approx_token_count(current) < max_tokens:
             continue

         if overlap <= 0:

View File

@@ -156,9 +156,11 @@ class Chunk(Base):
     collection_name = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     checked_at = Column(DateTime(timezone=True), server_default=func.now())
+
     vector: list[float] = []
     item_metadata: dict[str, Any] = {}
     images: list[Image.Image] = []
+    relevance_score: float = 0.0

     # One of file_path or content must be populated
     __table_args__ = (

src/memory/common/llms.py (new file, 122 lines)
View File

@@ -0,0 +1,122 @@
import logging
import base64
import io
from typing import Any

from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

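# Usage sketch (editorial illustration, not in the committed file):
#   encode_image(Image.new("RGBA", (1, 1)))
# converts the image to RGB, JPEG-encodes it, and returns base64 text that both
# call_openai and call_anthropic embed in their image payloads below.
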
def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content
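
A worked example of truncate (assumed inputs, not part of the commit): with CHARS_PER_TOKEN = 4, a 2-token budget keeps at most 8 characters, then backs off to the last complete word before appending the ellipsis:

    from memory.common.llms import truncate

    # "hello brave new world"[:8] == "hello br", which backs off to "hello"
    print(truncate("hello brave new world", 2))  # hello...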

View File

@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")

 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))

 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))

View File

@@ -4,7 +4,7 @@ import logging
 from typing import Any

 from bs4 import BeautifulSoup

-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms

 logger = logging.getLogger(__name__)

@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
     return {"summary": summary, "tags": tags}


-def _call_openai(prompt: str) -> dict[str, Any]:
-    """Call OpenAI API for summarization."""
-    import openai
-
-    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
-    try:
-        response = client.chat.completions.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
-                },
-                {"role": "user", "content": prompt},
-            ],
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.choices[0].message.content or "")
-    except Exception as e:
-        logger.error(f"OpenAI API error: {e}")
-        raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
-    """Call Anthropic API for summarization."""
-    import anthropic
-
-    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
-    try:
-        response = client.messages.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[{"role": "user", "content": prompt}],
-            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.content[0].text)
-    except Exception as e:
-        logger.error(f"Anthropic API error: {e}")
-        logger.error(response.content[0].text)
-        raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
-    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
-    if len(content) > target_chars:
-        return content[:target_chars].rsplit(" ", 1)[0] + "..."
-    return content
-
-
 def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     """
     Summarize content to approximately target_tokens length and generate tags.

@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     summary, tags = content, []

     # If content is already short enough, just extract tags
-    current_tokens = chunker.approx_token_count(content)
+    current_tokens = tokens.approx_token_count(content)
     if current_tokens <= target_tokens:
         logger.info(
             f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     else:
         prompt = SUMMARY_PROMPT.format(
             target_tokens=target_tokens,
-            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+            target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
             content=content,
         )
-        if chunker.approx_token_count(prompt) > MAX_TOKENS:
+        if tokens.approx_token_count(prompt) > MAX_TOKENS:
             logger.warning(
-                f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
+                f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
             )
-            prompt = truncate(prompt, MAX_TOKENS - 20)
+            prompt = llms.truncate(prompt, MAX_TOKENS - 20)

         try:
-            if settings.SUMMARIZER_MODEL.startswith("anthropic"):
-                result = _call_anthropic(prompt)
-            else:
-                result = _call_openai(prompt)
+            response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+            result = parse_response(response)

             summary = result.get("summary", "")
             tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
         traceback.print_exc()
         logger.error(f"Summarization failed: {e}")

-    tokens = chunker.approx_token_count(summary)
-    if tokens > target_tokens * 1.5:
-        logger.warning(f"Summary too long ({tokens} tokens), truncating")
-        summary = truncate(content, target_tokens)
+    summary_tokens = tokens.approx_token_count(summary)
+    if summary_tokens > target_tokens * 1.5:
+        logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
+        summary = llms.truncate(content, target_tokens)

     return summary, tags

src/memory/common/tokens.py (new file, 128 lines)
View File

@@ -0,0 +1,128 @@
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)

CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN

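# Illustrative check (not in the committed file): approx_token_count("abcdefgh") == 2,
# since 8 // CHARS_PER_TOKEN == 2; strings shorter than CHARS_PER_TOKEN characters
# round down to 0 tokens.
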
def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85

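# Worked example (editorial illustration, not in the committed file): a 4032x3024
# photo at detail="high" is scaled to 2048x1536, then to 1024x768 so the short side
# is 768px, giving ceil(1024/512) * ceil(768/512) = 2 * 2 = 4 tiles, so
# 4 * 170 + 85 = 765 tokens.
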
def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens

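# Illustrative check (not in the committed file): a 1024x768 image has 786,432 pixels,
# so this heuristic returns ceil(786.432 * 1.15) + 100 = 905 + 100 = 1005 tokens.
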
def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens
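
A short usage sketch of the new estimators (the prompt and in-memory image are assumptions for illustration):

    from PIL import Image
    from memory.common.tokens import estimate_total_tokens

    img = Image.new("RGB", (1024, 768))
    # "Describe this image." is 20 chars -> 20 // 4 = 5 text tokens; the image adds
    # 1005 tokens on the Anthropic path, for 1010 total.
    print(estimate_total_tokens("Describe this image.", [img], "anthropic/claude-3-haiku-20240307"))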

View File

@@ -1,5 +1,6 @@
 import pytest

-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN

 @pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
         (" ", []),  # Just spaces
         ("\n\t ", []),  # Whitespace characters
         ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
-    ]
+    ],
 )
 def test_yield_word_chunk_basic_behavior(text, expected):
     """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
     [
         (
             "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
-            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
         ),
         (
             "supercalifragilisticexpialidocious",
             ["supercalifragilisticexpialidocious"],
-        )
-    ]
+        ),
+    ],
 )
 def test_yield_word_chunk_long_text(text, expected):
     """Test chunking with long text that exceeds token limits"""
@@ -53,7 +54,12 @@ def test_yield_word_chunk_small_token_limit():
     text = "one two three four five"
     max_tokens = 1  # Very small to force chunking after each word

-    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+    assert list(yield_word_chunks(text, max_tokens)) == [
+        "one two",
+        "three",
+        "four",
+        "five",
+    ]
 @pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
         (
             "word1 word2",  # 11 chars with space
             3,  # 12 chars limit
-            ["word1 word2"]
+            ["word1 word2"],
         ),
         # Text just over token limit should split
         (
             "word1 word2 word3",  # 17 chars with spaces
             4,  # 16 chars limit
-            ["word1 word2 word3"]
+            ["word1 word2 word3"],
         ),
         # Each word exactly at token limit
         (
             "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
             1,  # 1 token limit (4 chars)
-            ["aaaa", "bbbb", "cccc"]
+            ["aaaa", "bbbb", "cccc"],
         ),
-    ]
+    ],
 )
 def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
     """Test different combinations of text and token limits"""
@@ -98,12 +104,12 @@ def test_yield_word_chunk_real_world_example():
     max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4

     assert list(yield_word_chunks(text, max_tokens)) == [
-        'The yield_word_chunks function splits text',
-        'into chunks based on word boundaries. It',
-        'tries to maximize chunk size while staying',
-        'under the specified token limit. This',
-        'behavior is essential for processing large',
-        'documents efficiently.',
+        "The yield_word_chunks function splits text",
+        "into chunks based on word boundaries. It",
+        "tries to maximize chunk size while staying",
+        "under the specified token limit. This",
+        "behavior is essential for processing large",
+        "documents efficiently.",
     ]
@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
     "text, expected",
     [
         ("", []),  # Empty text should yield nothing
-        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
+        (
+            "Simple paragraph",
+            ["Simple paragraph"],
+        ),  # Single paragraph under token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_yield_spans_basic_behavior(text, expected):
     """Test basic behavior of yield_spans with various inputs"""
@@ -146,7 +155,13 @@ def test_yield_spans_words():
     max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
     long_sentence = "This sentence has several words and needs word-level chunking."

-    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+    assert list(yield_spans(long_sentence, max_tokens)) == [
+        "This sentence",
+        "has several",
+        "words and needs",
+        "word-level",
+        "chunking.",
+    ]


 def test_yield_spans_complex_document():
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
         "Some are short.",
         "This one is longer and might need word",
         "splitting depending on the limit.",
-        "Final short paragraph."
+        "Final short paragraph.",
     ]

@@ -183,7 +198,11 @@ def test_yield_spans_with_punctuation():
     """Test sentence splitting with various punctuation"""
     text = "First sentence! Second sentence? Third sentence."

-    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+    assert list(yield_spans(text, max_tokens=10)) == [
+        "First sentence!",
+        "Second sentence?",
+        "Third sentence.",
+    ]


 def test_yield_spans_edge_cases():
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
         ("", []),  # Empty text
         ("Short text", ["Short text"]),  # Text below token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_chunk_text_basic_behavior(text, expected):
     """Test basic behavior of chunk_text with various inputs"""
@@ -226,10 +245,8 @@ def test_chunk_text_long_text():
     max_tokens = 10  # 10 tokens = ~40 chars

     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
-        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
-    ] + [
-        'This is sentence 49.'
-    ]
+        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+    ] + ["This is sentence 49."]


 def test_chunk_text_with_overlap():
@@ -237,7 +254,11 @@ def test_chunk_text_with_overlap():
     # Create text with distinct parts to test overlap
     text = "Part A. Part B. Part C. Part D. Part E."

-    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+        "Part A. Part B. Part C.",
+        "Part C. Part D. Part E.",
+        "Part E.",
+    ]


 def test_chunk_text_zero_overlap():
@@ -245,7 +266,11 @@ def test_chunk_text_zero_overlap():
     text = "Part A. Part B. Part C. Part D. Part E."

     # 2 tokens = ~8 chars
-    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+        "Part A. Part B.",
+        "Part C. Part D.",
+        "Part E.",
+    ]


 def test_chunk_text_clean_break():
@@ -253,7 +278,10 @@ def test_chunk_text_clean_break():
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."
     max_tokens = 5  # Enough for about 2 sentences

-    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+        "First sentence. Second sentence.",
+        "Third sentence. Fourth sentence.",
+    ]


 def test_chunk_text_very_long_sentences():
@@ -262,24 +290,10 @@ def test_chunk_text_very_long_sentences():
     max_tokens = 5  # Small limit to force splitting

     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        'This is a very long sentence with many many',
-        'words that will definitely exceed the',
-        'token limit we set for',
-        'this particular test',
-        'case and should be split into multiple',
-        'chunks by the function.',
+        "This is a very long sentence with many many",
+        "words that will definitely exceed the",
+        "token limit we set for",
+        "this particular test",
+        "case and should be split into multiple",
+        "chunks by the function.",
     ]
-
-
-@pytest.mark.parametrize(
-    "string, expected_count",
-    [
-        ("", 0),
-        ("a" * CHARS_PER_TOKEN, 1),
-        ("a" * (CHARS_PER_TOKEN * 2), 2),
-        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
-        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
-    ]
-)
-def test_approx_token_count(string, expected_count):
-    assert approx_token_count(string) == expected_count

View File

@@ -0,0 +1,16 @@
import pytest

from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count