second pass in search

Daniel O'Connell 2025-06-28 20:59:15 +02:00
parent 06eec621c1
commit 8eb6374cac
11 changed files with 445 additions and 142 deletions

View File

@@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
if (isLoading) {
return <Loading message="Searching..." />
}
console.log("results",results)
return (
<div className="search-results">
{results.length > 0 && (

View File

@@ -0,0 +1,65 @@
import asyncio
from bs4 import BeautifulSoup
from PIL import Image
from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens
SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""
SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>
Here is the chunk:
<chunk>
{chunk}
</chunk>
Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>
Please always return a summary of any images provided.
"""
async def score_chunk(query: str, chunk: Chunk) -> Chunk:
data = chunk.data
chunk_text = "\n".join(text for text in data if isinstance(text, str))
images = [image for image in data if isinstance(image, Image.Image)]
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
response = await asyncio.to_thread(
llms.call,
prompt,
settings.RANKER_MODEL,
images=images,
system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
)
soup = BeautifulSoup(response, "html.parser")
if not (score := soup.find("score")):
chunk.relevance_score = 0.0
else:
try:
chunk.relevance_score = float(score.text.strip())
except ValueError:
chunk.relevance_score = 0.0
return chunk
async def rank_chunks(
query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
calls = [score_chunk(query, chunk) for chunk in chunks]
scored = await asyncio.gather(*calls)
return sorted(
[chunk for chunk in scored if chunk.relevance_score >= min_score],
key=lambda x: x.relevance_score or 0,
reverse=True,
)
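A minimal usage sketch for the new scorer module (hypothetical: assumes an async context, a configured settings.RANKER_MODEL, and chunks that would normally come from the database):

import asyncio

async def demo(query: str, chunks: list[Chunk]) -> None:
    # rank_chunks scores all chunks concurrently, drops those under
    # min_score, and returns the rest sorted best-first by relevance_score.
    ranked = await rank_chunks(query, chunks, min_score=0.3)
    for chunk in ranked:
        print(chunk.relevance_score, chunk.id)

# asyncio.run(demo("how are image tokens estimated?", some_chunks))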

View File

@@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
from memory.api.search import scorer
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25_chunks
@@ -40,7 +41,14 @@ async def search_chunks(
with make_session() as db:
chunks = (
db.query(Chunk)
.options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
.options(
load_only(
Chunk.id, # type: ignore
Chunk.source_id, # type: ignore
Chunk.content, # type: ignore
Chunk.file_paths, # type: ignore
)
)
.filter(Chunk.id.in_(all_ids))
.all()
)
@@ -91,4 +99,6 @@ async def search(
filters,
timeout,
)
if settings.ENABLE_SEARCH_SCORING:
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
return await search_sources(chunks, previews)
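The re-rank pass is gated on the new ENABLE_SEARCH_SCORING setting. Spelled out as a sketch (treating data[0].data[0] as the raw text of the first query payload is an assumption from context):

query_text = data[0].data[0]  # assumed: the plain-text form of the search query
if settings.ENABLE_SEARCH_SCORING:
    # Second pass: an LLM scores every candidate chunk against the query,
    # drops anything under 0.3, and orders the survivors best-first.
    chunks = await scorer.rank_chunks(query_text, chunks, min_score=0.3)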

View File

@@ -2,7 +2,7 @@ import logging
from typing import Iterable, Any
import re
from memory.common import settings
from memory.common import settings, tokens
logger = logging.getLogger(__name__)
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
CHARS_PER_TOKEN = 4
Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
def approx_token_count(s: str) -> int:
return len(s) // CHARS_PER_TOKEN
def yield_word_chunks(
text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
current = ""
for word in words:
new_chunk = f"{current} {word}".strip()
if current and approx_token_count(new_chunk) > max_tokens:
if current and tokens.approx_token_count(new_chunk) > max_tokens:
yield current
current = word
else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
if not paragraph.strip():
continue
if approx_token_count(paragraph) <= max_tokens:
if tokens.approx_token_count(paragraph) <= max_tokens:
yield paragraph
continue
@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
if not sentence.strip():
continue
if approx_token_count(sentence) <= max_tokens:
if tokens.approx_token_count(sentence) <= max_tokens:
yield sentence
continue
@@ -99,16 +94,16 @@ def chunk_text(
if not text:
return
if approx_token_count(text) <= max_tokens:
if tokens.approx_token_count(text) <= max_tokens:
yield text
return
overlap_chars = overlap * CHARS_PER_TOKEN
overlap_chars = overlap * tokens.CHARS_PER_TOKEN
current = ""
for span in yield_spans(text, max_tokens):
current = f"{current} {span}".strip()
if approx_token_count(current) < max_tokens:
if tokens.approx_token_count(current) < max_tokens:
continue
if overlap <= 0:

View File

@@ -156,9 +156,11 @@ class Chunk(Base):
collection_name = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
relevance_score: float = 0.0
# One of file_path or content must be populated
__table_args__ = (
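Note that relevance_score, like vector, item_metadata, and images above it, is a plain class attribute rather than a mapped Column: it exists only on the in-memory object and is never persisted. A minimal sketch of the implied behavior (the db session is hypothetical):

chunk = db.query(Chunk).first()
chunk.relevance_score = 0.87  # transient: set per-search by scorer.score_chunk
db.commit()  # mapped columns persist; relevance_score is discarded with the object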

src/memory/common/llms.py Normal file (+122)
View File

@@ -0,0 +1,122 @@
import logging
import base64
import io
from typing import Any
from PIL import Image
from memory.common import settings, tokens
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""
def encode_image(image: Image.Image) -> str:
"""Encode PIL Image to base64 string."""
buffer = io.BytesIO()
# Convert to RGB if necessary (for RGBA, etc.)
if image.mode != "RGB":
image = image.convert("RGB")
image.save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def call_openai(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
"""Call OpenAI API for summarization."""
import openai
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
try:
user_content: Any = [{"type": "text", "text": prompt}]
if images:
for image in images:
encoded_image = encode_image(image)
user_content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
}
)
response = client.chat.completions.create(
model=model.split("/")[1],
messages=[
{
"role": "system",
"content": system_prompt,
},
{"role": "user", "content": user_content},
],
temperature=0.3,
max_tokens=2048,
)
return response.choices[0].message.content or ""
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
def call_anthropic(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
"""Call Anthropic API for summarization."""
import anthropic
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
try:
# Prepare the message content
content: Any = [{"type": "text", "text": prompt}]
if images:
# Add images if provided
for image in images:
encoded_image = encode_image(image)
content.append(
{ # type: ignore
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": encoded_image,
},
}
)
response = client.messages.create(
model=model.split("/")[1],
messages=[{"role": "user", "content": content}], # type: ignore
system=system_prompt,
temperature=0.3,
max_tokens=2048,
)
return response.content[0].text
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
def call(
prompt: str,
model: str,
images: list[Image.Image] = [],
system_prompt: str = SYSTEM_PROMPT,
) -> str:
if model.startswith("anthropic"):
return call_anthropic(prompt, model, images, system_prompt)
return call_openai(prompt, model, images, system_prompt)
def truncate(content: str, target_tokens: int) -> str:
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
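Model strings follow a provider/model convention: the prefix selects the SDK and everything after the first slash is passed through as the provider's model name. A hedged usage sketch (the image path is illustrative):

from PIL import Image
from memory.common import llms

text = llms.call(
    "Describe this diagram.",
    "anthropic/claude-3-haiku-20240307",  # routed to call_anthropic
    images=[Image.open("diagram.png")],  # illustrative path
)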

View File

@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))

View File

@@ -4,7 +4,7 @@ from typing import Any
from bs4 import BeautifulSoup
from memory.common import settings, chunker
from memory.common import settings, tokens, llms
logger = logging.getLogger(__name__)
@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
return {"summary": summary, "tags": tags}
def _call_openai(prompt: str) -> dict[str, Any]:
"""Call OpenAI API for summarization."""
import openai
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
try:
response = client.chat.completions.create(
model=settings.SUMMARIZER_MODEL.split("/")[1],
messages=[
{
"role": "system",
"content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
},
{"role": "user", "content": prompt},
],
temperature=0.3,
max_tokens=2048,
)
return parse_response(response.choices[0].message.content or "")
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
def _call_anthropic(prompt: str) -> dict[str, Any]:
"""Call Anthropic API for summarization."""
import anthropic
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
try:
response = client.messages.create(
model=settings.SUMMARIZER_MODEL.split("/")[1],
messages=[{"role": "user", "content": prompt}],
system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
temperature=0.3,
max_tokens=2048,
)
return parse_response(response.content[0].text)
except Exception as e:
logger.error(f"Anthropic API error: {e}")
logger.error(response.content[0].text)
raise
def truncate(content: str, target_tokens: int) -> str:
target_chars = target_tokens * chunker.CHARS_PER_TOKEN
if len(content) > target_chars:
return content[:target_chars].rsplit(" ", 1)[0] + "..."
return content
def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
"""
Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
summary, tags = content, []
# If content is already short enough, just extract tags
current_tokens = chunker.approx_token_count(content)
current_tokens = tokens.approx_token_count(content)
if current_tokens <= target_tokens:
logger.info(
f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
else:
prompt = SUMMARY_PROMPT.format(
target_tokens=target_tokens,
target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
content=content,
)
if chunker.approx_token_count(prompt) > MAX_TOKENS:
if tokens.approx_token_count(prompt) > MAX_TOKENS:
logger.warning(
f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
)
prompt = truncate(prompt, MAX_TOKENS - 20)
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
try:
if settings.SUMMARIZER_MODEL.startswith("anthropic"):
result = _call_anthropic(prompt)
else:
result = _call_openai(prompt)
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
result = parse_response(response)
summary = result.get("summary", "")
tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
traceback.print_exc()
logger.error(f"Summarization failed: {e}")
tokens = chunker.approx_token_count(summary)
if tokens > target_tokens * 1.5:
logger.warning(f"Summary too long ({tokens} tokens), truncating")
summary = truncate(content, target_tokens)
summary_tokens = tokens.approx_token_count(summary)
if summary_tokens > target_tokens * 1.5:
logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
summary = llms.truncate(content, target_tokens)
return summary, tags
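With the provider branches gone, summarize() is now a thin wrapper around llms.call plus parse_response, so call sites are unchanged (values illustrative):

summary, tags = summarize(long_article_text, target_tokens=100)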

src/memory/common/tokens.py Normal file (+128)
View File

@@ -0,0 +1,128 @@
import logging
from PIL import Image
import math
logger = logging.getLogger(__name__)
CHARS_PER_TOKEN = 4
def approx_token_count(s: str) -> int:
return len(s) // CHARS_PER_TOKEN
def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
"""
Estimate tokens for an image using OpenAI's counting method.
Args:
image: PIL Image
detail: "high" or "low" detail level
Returns:
Estimated token count
"""
if detail == "low":
return 85
# For high detail, OpenAI resizes the image to fit within 2048x2048
# while maintaining aspect ratio, then counts 512x512 tiles
width, height = image.size
# Resize logic to fit within 2048x2048
if width > 2048 or height > 2048:
if width > height:
height = int(height * 2048 / width)
width = 2048
else:
width = int(width * 2048 / height)
height = 2048
# Further resize so shortest side is 768px
if width < height:
if width > 768:
height = int(height * 768 / width)
width = 768
else:
if height > 768:
width = int(width * 768 / height)
height = 768
# Count 512x512 tiles
tiles_width = math.ceil(width / 512)
tiles_height = math.ceil(height / 512)
total_tiles = tiles_width * tiles_height
# Each tile costs 170 tokens, plus 85 base tokens
return total_tiles * 170 + 85
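A worked example of the high-detail path above: a 1024x1024 image already fits within 2048x2048, gets scaled so its shorter side is 768 (to 768x768), and spans ceil(768/512) * ceil(768/512) = 4 tiles, giving 4 * 170 + 85 = 765 tokens:

from PIL import Image

img = Image.new("RGB", (1024, 1024))
assert estimate_openai_image_tokens(img, detail="high") == 765
assert estimate_openai_image_tokens(img, detail="low") == 85  # flat low-detail cost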
def estimate_anthropic_image_tokens(image: Image.Image) -> int:
"""
Estimate tokens for an image using Anthropic's counting method.
Args:
image: PIL Image
Returns:
Estimated token count
"""
width, height = image.size
# Anthropic's token counting is based on image dimensions
# They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
# But they also have a base cost per image
# Rough approximation based on Anthropic's documentation
# They count tokens based on the image size after potential resizing
total_pixels = width * height
# Anthropic typically charges around 1.15 tokens per 1000 pixels
# with a minimum base cost
base_tokens = 100 # Base cost for any image
pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)
return base_tokens + pixel_tokens
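Under this approximation, the same 1024x1024 image has 1,048,576 pixels, so pixel_tokens = ceil(1048.576 * 1.15) = 1206, plus the 100-token base, gives 1306. (The comments above already flag this as a rough model of Anthropic's real accounting.)

img = Image.new("RGB", (1024, 1024))
assert estimate_anthropic_image_tokens(img) == 1306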
def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
"""
Estimate tokens for an image based on the model provider.
Args:
image: PIL Image
model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
detail: Detail level for OpenAI models ("high" or "low")
Returns:
Estimated token count
"""
if model.startswith("anthropic"):
return estimate_anthropic_image_tokens(image)
else:
return estimate_openai_image_tokens(image, detail)
def estimate_total_tokens(
prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
"""
Estimate total tokens for a prompt with images.
Args:
prompt: Text prompt
images: List of PIL Images
model: Model string
detail: Detail level for OpenAI models
Returns:
Estimated total token count
"""
# Estimate text tokens
text_tokens = approx_token_count(prompt)
# Estimate image tokens
image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)
return text_tokens + image_tokens
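Tying it together, with figures from the two worked examples above and the 4-chars-per-token text heuristic (model names illustrative):

prompt = "x" * 400  # 400 chars // 4 = 100 text tokens
img = Image.new("RGB", (1024, 1024))
assert estimate_total_tokens(prompt, [img], "openai/gpt-4o") == 100 + 765
assert estimate_total_tokens(prompt, [img], "anthropic/claude-3-haiku-20240307") == 100 + 1306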

View File

@@ -1,5 +1,6 @@
import pytest
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
from memory.common.tokens import CHARS_PER_TOKEN
@pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
(" ", []), # Just spaces
("\n\t ", []), # Whitespace characters
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
]
],
)
def test_yield_word_chunk_basic_behavior(text, expected):
"""Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
[
(
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
),
(
"supercalifragilisticexpialidocious",
["supercalifragilisticexpialidocious"],
)
]
),
],
)
def test_yield_word_chunk_long_text(text, expected):
"""Test chunking with long text that exceeds token limits"""
@@ -53,7 +54,12 @@ def test_yield_word_chunk_small_token_limit():
text = "one two three four five"
max_tokens = 1 # Very small to force chunking after each word
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
assert list(yield_word_chunks(text, max_tokens)) == [
"one two",
"three",
"four",
"five",
]
@pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
(
"word1 word2", # 11 chars with space
3, # 12 chars limit
["word1 word2"]
["word1 word2"],
),
# Text just over token limit should split
(
"word1 word2 word3", # 17 chars with spaces
4, # 16 chars limit
["word1 word2 word3"]
["word1 word2 word3"],
),
# Each word exactly at token limit
(
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
1, # 1 token limit (4 chars)
["aaaa", "bbbb", "cccc"]
["aaaa", "bbbb", "cccc"],
),
]
],
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
"""Test different combinations of text and token limits"""
@@ -98,12 +104,12 @@ def test_yield_word_chunk_real_world_example():
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
assert list(yield_word_chunks(text, max_tokens)) == [
'The yield_word_chunks function splits text',
'into chunks based on word boundaries. It',
'tries to maximize chunk size while staying',
'under the specified token limit. This',
'behavior is essential for processing large',
'documents efficiently.',
"The yield_word_chunks function splits text",
"into chunks based on word boundaries. It",
"tries to maximize chunk size while staying",
"under the specified token limit. This",
"behavior is essential for processing large",
"documents efficiently.",
]
@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
"text, expected",
[
("", []), # Empty text should yield nothing
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
(
"Simple paragraph",
["Simple paragraph"],
), # Single paragraph under token limit
(" ", []), # Just whitespace
]
],
)
def test_yield_spans_basic_behavior(text, expected):
"""Test basic behavior of yield_spans with various inputs"""
@@ -146,7 +155,13 @@ def test_yield_spans_words():
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
long_sentence = "This sentence has several words and needs word-level chunking."
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
assert list(yield_spans(long_sentence, max_tokens)) == [
"This sentence",
"has several",
"words and needs",
"word-level",
"chunking.",
]
def test_yield_spans_complex_document():
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
"Some are short.",
"This one is longer and might need word",
"splitting depending on the limit.",
"Final short paragraph."
"Final short paragraph.",
]
@@ -183,7 +198,11 @@ def test_yield_spans_with_punctuation():
"""Test sentence splitting with various punctuation"""
text = "First sentence! Second sentence? Third sentence."
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
assert list(yield_spans(text, max_tokens=10)) == [
"First sentence!",
"Second sentence?",
"Third sentence.",
]
def test_yield_spans_edge_cases():
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
("", []), # Empty text
("Short text", ["Short text"]), # Text below token limit
(" ", []), # Just whitespace
]
],
)
def test_chunk_text_basic_behavior(text, expected):
"""Test basic behavior of chunk_text with various inputs"""
@@ -226,10 +245,8 @@ def test_chunk_text_long_text():
max_tokens = 10 # 10 tokens = ~40 chars
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
] + [
'This is sentence 49.'
]
f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
] + ["This is sentence 49."]
def test_chunk_text_with_overlap():
@@ -237,7 +254,11 @@ def test_chunk_text_with_overlap():
# Create text with distinct parts to test overlap
text = "Part A. Part B. Part C. Part D. Part E."
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
"Part A. Part B. Part C.",
"Part C. Part D. Part E.",
"Part E.",
]
def test_chunk_text_zero_overlap():
@@ -245,7 +266,11 @@ def test_chunk_text_zero_overlap():
text = "Part A. Part B. Part C. Part D. Part E."
# 2 tokens = ~8 chars
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
"Part A. Part B.",
"Part C. Part D.",
"Part E.",
]
def test_chunk_text_clean_break():
@@ -253,7 +278,10 @@ def test_chunk_text_clean_break():
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
max_tokens = 5 # Enough for about 2 sentences
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
"First sentence. Second sentence.",
"Third sentence. Fourth sentence.",
]
def test_chunk_text_very_long_sentences():
@@ -262,24 +290,10 @@ def test_chunk_text_very_long_sentences():
max_tokens = 5 # Small limit to force splitting
assert list(chunk_text(text, max_tokens=max_tokens)) == [
'This is a very long sentence with many many',
'words that will definitely exceed the',
'token limit we set for',
'this particular test',
'case and should be split into multiple',
'chunks by the function.',
"This is a very long sentence with many many",
"words that will definitely exceed the",
"token limit we set for",
"this particular test",
"case and should be split into multiple",
"chunks by the function.",
]
@pytest.mark.parametrize(
"string, expected_count",
[
("", 0),
("a" * CHARS_PER_TOKEN, 1),
("a" * (CHARS_PER_TOKEN * 2), 2),
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
]
)
def test_approx_token_count(string, expected_count):
assert approx_token_count(string) == expected_count

View File

@@ -0,0 +1,16 @@
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count
@pytest.mark.parametrize(
"string, expected_count",
[
("", 0),
("a" * CHARS_PER_TOKEN, 1),
("a" * (CHARS_PER_TOKEN * 2), 2),
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
],
)
def test_approx_token_count(string, expected_count):
assert approx_token_count(string) == expected_count