diff --git a/frontend/src/components/search/Search.tsx b/frontend/src/components/search/Search.tsx
index 02302ca..96c4461 100644
--- a/frontend/src/components/search/Search.tsx
+++ b/frontend/src/components/search/Search.tsx
@@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
if (isLoading) {
return
}
+ console.log("results",results)
return (
{results.length > 0 && (
diff --git a/src/memory/api/search/scorer.py b/src/memory/api/search/scorer.py
new file mode 100644
index 0000000..bf40315
--- /dev/null
+++ b/src/memory/api/search/scorer.py
@@ -0,0 +1,65 @@
+import asyncio
+from bs4 import BeautifulSoup
+from PIL import Image
+
+from memory.common.db.models.source_item import Chunk
+from memory.common import llms, settings, tokens
+
+
+SCORE_CHUNK_SYSTEM_PROMPT = """
+You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
+
+You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
+"""
+
+SCORE_CHUNK_PROMPT = """
+Here is the query:
+{query}
+
+Here is the chunk:
+
+ {chunk}
+
+
+Please return your score as a number between 0 and 1 formatted as:
+your score
+
+Please always return a summary of any images provided.
+"""
+
+
+async def score_chunk(query: str, chunk: Chunk) -> Chunk:
+ data = chunk.data
+ chunk_text = "\n".join(text for text in data if isinstance(text, str))
+ images = [image for image in data if isinstance(image, Image.Image)]
+ prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
+ response = await asyncio.to_thread(
+ llms.call,
+ prompt,
+ settings.RANKER_MODEL,
+ images=images,
+ system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
+ )
+
+ soup = BeautifulSoup(response, "html.parser")
+ if not (score := soup.find("score")):
+ chunk.relevance_score = 0.0
+ else:
+ try:
+ chunk.relevance_score = float(score.text.strip())
+ except ValueError:
+ chunk.relevance_score = 0.0
+
+ return chunk
+
+
+async def rank_chunks(
+ query: str, chunks: list[Chunk], min_score: float = 0
+) -> list[Chunk]:
+ calls = [score_chunk(query, chunk) for chunk in chunks]
+ scored = await asyncio.gather(*calls)
+ return sorted(
+ [chunk for chunk in scored if chunk.relevance_score >= min_score],
+ key=lambda x: x.relevance_score or 0,
+ reverse=True,
+ )
diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py
index 85b39c0..0df208c 100644
--- a/src/memory/api/search/search.py
+++ b/src/memory/api/search/search.py
@@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25_chunks
@@ -40,7 +41,14 @@ async def search_chunks(
with make_session() as db:
chunks = (
db.query(Chunk)
- .options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
+ .options(
+ load_only(
+ Chunk.id, # type: ignore
+ Chunk.source_id, # type: ignore
+ Chunk.content, # type: ignore
+ Chunk.file_paths, # type: ignore
+ )
+ )
.filter(Chunk.id.in_(all_ids))
.all()
)
@@ -91,4 +99,6 @@ async def search(
filters,
timeout,
)
+ if settings.ENABLE_SEARCH_SCORING:
+ chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
return await search_sources(chunks, previews)
diff --git a/src/memory/common/chunker.py b/src/memory/common/chunker.py
index 71bb322..a6bdfcb 100644
--- a/src/memory/common/chunker.py
+++ b/src/memory/common/chunker.py
@@ -2,7 +2,7 @@ import logging
from typing import Iterable, Any
import re
-from memory.common import settings
+from memory.common import settings, tokens
logger = logging.getLogger(__name__)
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4
Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
-def approx_token_count(s: str) -> int:
- return len(s) // CHARS_PER_TOKEN
-
-
def yield_word_chunks(
text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
current = ""
for word in words:
new_chunk = f"{current} {word}".strip()
- if current and approx_token_count(new_chunk) > max_tokens:
+ if current and tokens.approx_token_count(new_chunk) > max_tokens:
yield current
current = word
else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
if not paragraph.strip():
continue
- if approx_token_count(paragraph) <= max_tokens:
+ if tokens.approx_token_count(paragraph) <= max_tokens:
yield paragraph
continue
@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
if not sentence.strip():
continue
- if approx_token_count(sentence) <= max_tokens:
+ if tokens.approx_token_count(sentence) <= max_tokens:
yield sentence
continue
@@ -99,16 +94,16 @@ def chunk_text(
if not text:
return
- if approx_token_count(text) <= max_tokens:
+ if tokens.approx_token_count(text) <= max_tokens:
yield text
return
- overlap_chars = overlap * CHARS_PER_TOKEN
+ overlap_chars = overlap * tokens.CHARS_PER_TOKEN
current = ""
for span in yield_spans(text, max_tokens):
current = f"{current} {span}".strip()
- if approx_token_count(current) < max_tokens:
+ if tokens.approx_token_count(current) < max_tokens:
continue
if overlap <= 0:
diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py
index 0035921..42d2280 100644
--- a/src/memory/common/db/models/source_item.py
+++ b/src/memory/common/db/models/source_item.py
@@ -156,9 +156,11 @@ class Chunk(Base):
collection_name = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
+
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
+ relevance_score: float = 0.0
# One of file_path or content must be populated
__table_args__ = (
diff --git a/src/memory/common/llms.py b/src/memory/common/llms.py
new file mode 100644
index 0000000..8b88b75
--- /dev/null
+++ b/src/memory/common/llms.py
@@ -0,0 +1,122 @@
+import logging
+import base64
+import io
+from typing import Any
+from PIL import Image
+
+from memory.common import settings, tokens
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = """
+You are a helpful assistant that creates concise summaries and identifies key topics.
+"""
+
+
+def encode_image(image: Image.Image) -> str:
+ """Encode PIL Image to base64 string."""
+ buffer = io.BytesIO()
+ # Convert to RGB if necessary (for RGBA, etc.)
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ image.save(buffer, format="JPEG")
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+
+def call_openai(
+ prompt: str,
+ model: str,
+ images: list[Image.Image] = [],
+ system_prompt: str = SYSTEM_PROMPT,
+) -> str:
+ """Call OpenAI API for summarization."""
+ import openai
+
+ client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
+ try:
+ user_content: Any = [{"type": "text", "text": prompt}]
+ if images:
+ for image in images:
+ encoded_image = encode_image(image)
+ user_content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
+ }
+ )
+
+ response = client.chat.completions.create(
+ model=model.split("/")[1],
+ messages=[
+ {
+ "role": "system",
+ "content": system_prompt,
+ },
+ {"role": "user", "content": user_content},
+ ],
+ temperature=0.3,
+ max_tokens=2048,
+ )
+ return response.choices[0].message.content or ""
+ except Exception as e:
+ logger.error(f"OpenAI API error: {e}")
+ raise
+
+
+def call_anthropic(
+ prompt: str,
+ model: str,
+ images: list[Image.Image] = [],
+ system_prompt: str = SYSTEM_PROMPT,
+) -> str:
+ """Call Anthropic API for summarization."""
+ import anthropic
+
+ client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
+ try:
+ # Prepare the message content
+ content: Any = [{"type": "text", "text": prompt}]
+ if images:
+ # Add images if provided
+ for image in images:
+ encoded_image = encode_image(image)
+ content.append(
+ { # type: ignore
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": encoded_image,
+ },
+ }
+ )
+
+ response = client.messages.create(
+ model=model.split("/")[1],
+ messages=[{"role": "user", "content": content}], # type: ignore
+ system=system_prompt,
+ temperature=0.3,
+ max_tokens=2048,
+ )
+ return response.content[0].text
+ except Exception as e:
+ logger.error(f"Anthropic API error: {e}")
+ raise
+
+
+def call(
+ prompt: str,
+ model: str,
+ images: list[Image.Image] = [],
+ system_prompt: str = SYSTEM_PROMPT,
+) -> str:
+ if model.startswith("anthropic"):
+ return call_anthropic(prompt, model, images, system_prompt)
+ return call_openai(prompt, model, images, system_prompt)
+
+
+def truncate(content: str, target_tokens: int) -> str:
+ target_chars = target_tokens * tokens.CHARS_PER_TOKEN
+ if len(content) > target_chars:
+ return content[:target_chars].rsplit(" ", 1)[0] + "..."
+ return content
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 2943506..cae5fbd 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py
index 9843649..c2016f3 100644
--- a/src/memory/common/summarizer.py
+++ b/src/memory/common/summarizer.py
@@ -4,7 +4,7 @@ from typing import Any
from bs4 import BeautifulSoup
-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms
logger = logging.getLogger(__name__)
@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
return {"summary": summary, "tags": tags}
-def _call_openai(prompt: str) -> dict[str, Any]:
- """Call OpenAI API for summarization."""
- import openai
-
- client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
- try:
- response = client.chat.completions.create(
- model=settings.SUMMARIZER_MODEL.split("/")[1],
- messages=[
- {
- "role": "system",
- "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
- },
- {"role": "user", "content": prompt},
- ],
- temperature=0.3,
- max_tokens=2048,
- )
- return parse_response(response.choices[0].message.content or "")
- except Exception as e:
- logger.error(f"OpenAI API error: {e}")
- raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
- """Call Anthropic API for summarization."""
- import anthropic
-
- client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
- try:
- response = client.messages.create(
- model=settings.SUMMARIZER_MODEL.split("/")[1],
- messages=[{"role": "user", "content": prompt}],
- system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
- temperature=0.3,
- max_tokens=2048,
- )
- return parse_response(response.content[0].text)
- except Exception as e:
- logger.error(f"Anthropic API error: {e}")
- logger.error(response.content[0].text)
- raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
- target_chars = target_tokens * chunker.CHARS_PER_TOKEN
- if len(content) > target_chars:
- return content[:target_chars].rsplit(" ", 1)[0] + "..."
- return content
-
-
def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
"""
Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
summary, tags = content, []
# If content is already short enough, just extract tags
- current_tokens = chunker.approx_token_count(content)
+ current_tokens = tokens.approx_token_count(content)
if current_tokens <= target_tokens:
logger.info(
f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
else:
prompt = SUMMARY_PROMPT.format(
target_tokens=target_tokens,
- target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+ target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
content=content,
)
- if chunker.approx_token_count(prompt) > MAX_TOKENS:
+ if tokens.approx_token_count(prompt) > MAX_TOKENS:
logger.warning(
- f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
+ f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
)
- prompt = truncate(prompt, MAX_TOKENS - 20)
+ prompt = llms.truncate(prompt, MAX_TOKENS - 20)
try:
- if settings.SUMMARIZER_MODEL.startswith("anthropic"):
- result = _call_anthropic(prompt)
- else:
- result = _call_openai(prompt)
+ response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+ result = parse_response(response)
summary = result.get("summary", "")
tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
traceback.print_exc()
logger.error(f"Summarization failed: {e}")
- tokens = chunker.approx_token_count(summary)
- if tokens > target_tokens * 1.5:
- logger.warning(f"Summary too long ({tokens} tokens), truncating")
- summary = truncate(content, target_tokens)
+ summary_tokens = tokens.approx_token_count(summary)
+ if summary_tokens > target_tokens * 1.5:
+ logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
+ summary = llms.truncate(content, target_tokens)
return summary, tags
diff --git a/src/memory/common/tokens.py b/src/memory/common/tokens.py
new file mode 100644
index 0000000..4f2d942
--- /dev/null
+++ b/src/memory/common/tokens.py
@@ -0,0 +1,128 @@
+import logging
+from PIL import Image
+import math
+
+logger = logging.getLogger(__name__)
+
+
+CHARS_PER_TOKEN = 4
+
+
+def approx_token_count(s: str) -> int:
+ return len(s) // CHARS_PER_TOKEN
+
+
+def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
+ """
+ Estimate tokens for an image using OpenAI's counting method.
+
+ Args:
+ image: PIL Image
+ detail: "high" or "low" detail level
+
+ Returns:
+ Estimated token count
+ """
+ if detail == "low":
+ return 85
+
+ # For high detail, OpenAI resizes the image to fit within 2048x2048
+ # while maintaining aspect ratio, then counts 512x512 tiles
+ width, height = image.size
+
+ # Resize logic to fit within 2048x2048
+ if width > 2048 or height > 2048:
+ if width > height:
+ height = int(height * 2048 / width)
+ width = 2048
+ else:
+ width = int(width * 2048 / height)
+ height = 2048
+
+ # Further resize so shortest side is 768px
+ if width < height:
+ if width > 768:
+ height = int(height * 768 / width)
+ width = 768
+ else:
+ if height > 768:
+ width = int(width * 768 / height)
+ height = 768
+
+ # Count 512x512 tiles
+ tiles_width = math.ceil(width / 512)
+ tiles_height = math.ceil(height / 512)
+ total_tiles = tiles_width * tiles_height
+
+ # Each tile costs 170 tokens, plus 85 base tokens
+ return total_tiles * 170 + 85
+
+
+def estimate_anthropic_image_tokens(image: Image.Image) -> int:
+ """
+ Estimate tokens for an image using Anthropic's counting method.
+
+ Args:
+ image: PIL Image
+
+ Returns:
+ Estimated token count
+ """
+ width, height = image.size
+
+ # Anthropic's token counting is based on image dimensions
+ # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
+ # But they also have a base cost per image
+
+ # Rough approximation based on Anthropic's documentation
+ # They count tokens based on the image size after potential resizing
+ total_pixels = width * height
+
+ # Anthropic typically charges around 1.15 tokens per 1000 pixels
+ # with a minimum base cost
+ base_tokens = 100 # Base cost for any image
+ pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)
+
+ return base_tokens + pixel_tokens
+
+
+def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
+ """
+ Estimate tokens for an image based on the model provider.
+
+ Args:
+ image: PIL Image
+ model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
+ detail: Detail level for OpenAI models ("high" or "low")
+
+ Returns:
+ Estimated token count
+ """
+ if model.startswith("anthropic"):
+ return estimate_anthropic_image_tokens(image)
+ else:
+ return estimate_openai_image_tokens(image, detail)
+
+
+def estimate_total_tokens(
+ prompt: str, images: list[Image.Image], model: str, detail: str = "high"
+) -> int:
+ """
+ Estimate total tokens for a prompt with images.
+
+ Args:
+ prompt: Text prompt
+ images: List of PIL Images
+ model: Model string
+ detail: Detail level for OpenAI models
+
+ Returns:
+ Estimated total token count
+ """
+ # Estimate text tokens
+ text_tokens = approx_token_count(prompt)
+
+ # Estimate image tokens
+ image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)
+
+ return text_tokens + image_tokens
diff --git a/tests/memory/common/test_chunker.py b/tests/memory/common/test_chunker.py
index 3380dc4..ca6ec77 100644
--- a/tests/memory/common/test_chunker.py
+++ b/tests/memory/common/test_chunker.py
@@ -1,5 +1,6 @@
import pytest
-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN
@pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
(" ", []), # Just spaces
("\n\t ", []), # Whitespace characters
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
- ]
+ ],
)
def test_yield_word_chunk_basic_behavior(text, expected):
"""Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
[
(
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
- ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+ ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
),
(
"supercalifragilisticexpialidocious",
["supercalifragilisticexpialidocious"],
- )
- ]
+ ),
+ ],
)
def test_yield_word_chunk_long_text(text, expected):
"""Test chunking with long text that exceeds token limits"""
@@ -41,7 +42,7 @@ def test_yield_word_chunk_single_long_word():
"""Test behavior with a single word longer than the token limit"""
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
-
+
chunks = list(yield_word_chunks(long_word, max_tokens))
# With our changes, this should be a single chunk
assert len(chunks) == 1
@@ -52,8 +53,13 @@ def test_yield_word_chunk_small_token_limit():
"""Test with a very small max_tokens value to force chunking"""
text = "one two three four five"
max_tokens = 1 # Very small to force chunking after each word
-
- assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+
+ assert list(yield_word_chunks(text, max_tokens)) == [
+ "one two",
+ "three",
+ "four",
+ "five",
+ ]
@pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
(
"word1 word2", # 11 chars with space
3, # 12 chars limit
- ["word1 word2"]
+ ["word1 word2"],
),
# Text just over token limit should split
(
"word1 word2 word3", # 17 chars with spaces
4, # 16 chars limit
- ["word1 word2 word3"]
+ ["word1 word2 word3"],
),
# Each word exactly at token limit
(
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
1, # 1 token limit (4 chars)
- ["aaaa", "bbbb", "cccc"]
+ ["aaaa", "bbbb", "cccc"],
),
- ]
+ ],
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
"""Test different combinations of text and token limits"""
@@ -95,15 +101,15 @@ def test_yield_word_chunk_real_world_example():
"It tries to maximize chunk size while staying under the specified token limit. "
"This behavior is essential for processing large documents efficiently."
)
-
+
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
assert list(yield_word_chunks(text, max_tokens)) == [
- 'The yield_word_chunks function splits text',
- 'into chunks based on word boundaries. It',
- 'tries to maximize chunk size while staying',
- 'under the specified token limit. This',
- 'behavior is essential for processing large',
- 'documents efficiently.',
+ "The yield_word_chunks function splits text",
+ "into chunks based on word boundaries. It",
+ "tries to maximize chunk size while staying",
+ "under the specified token limit. This",
+ "behavior is essential for processing large",
+ "documents efficiently.",
]
@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
"text, expected",
[
("", []), # Empty text should yield nothing
- ("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
+ (
+ "Simple paragraph",
+ ["Simple paragraph"],
+ ), # Single paragraph under token limit
(" ", []), # Just whitespace
- ]
+ ],
)
def test_yield_spans_basic_behavior(text, expected):
"""Test basic behavior of yield_spans with various inputs"""
@@ -135,7 +144,7 @@ def test_yield_spans_sentences():
sentence1 = "Short sentence one." # ~20 chars
sentence2 = "Another short sentence." # ~24 chars
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
-
+
# Function should now preserve punctuation
expected = ["Short sentence one.", "Another short sentence."]
assert list(yield_spans(text, max_tokens)) == expected
@@ -145,8 +154,14 @@ def test_yield_spans_words():
"""Test splitting by words when sentences exceed token limit"""
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
long_sentence = "This sentence has several words and needs word-level chunking."
-
- assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+
+ assert list(yield_spans(long_sentence, max_tokens)) == [
+ "This sentence",
+ "has several",
+ "words and needs",
+ "word-level",
+ "chunking.",
+ ]
def test_yield_spans_complex_document():
@@ -158,7 +173,7 @@ def test_yield_spans_complex_document():
"This one is longer and might need word splitting depending on the limit.\n\n"
"Final short paragraph."
)
-
+
assert list(yield_spans(text, max_tokens)) == [
"Paragraph one with a short sentence.",
"And another sentence that should be split.",
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
"Some are short.",
"This one is longer and might need word",
"splitting depending on the limit.",
- "Final short paragraph."
+ "Final short paragraph.",
]
@@ -175,21 +190,25 @@ def test_yield_spans_very_long_word():
"""Test with a word that exceeds the token limit"""
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
-
+
assert list(yield_spans(long_word, max_tokens)) == [long_word]
def test_yield_spans_with_punctuation():
"""Test sentence splitting with various punctuation"""
text = "First sentence! Second sentence? Third sentence."
-
- assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+
+ assert list(yield_spans(text, max_tokens=10)) == [
+ "First sentence!",
+ "Second sentence?",
+ "Third sentence.",
+ ]
def test_yield_spans_edge_cases():
"""Test edge cases like empty paragraphs, single character paragraphs"""
text = "\n\nA\n\n\n\nB\n\n"
-
+
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
("", []), # Empty text
("Short text", ["Short text"]), # Text below token limit
(" ", []), # Just whitespace
- ]
+ ],
)
def test_chunk_text_basic_behavior(text, expected):
"""Test basic behavior of chunk_text with various inputs"""
@@ -223,63 +242,58 @@ def test_chunk_text_long_text():
# Create a long text that will need multiple chunks
sentences = [f"This is sentence {i:02}." for i in range(50)]
text = " ".join(sentences)
-
+
max_tokens = 10 # 10 tokens = ~40 chars
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
- f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
- ] + [
- 'This is sentence 49.'
- ]
-
+ f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+ ] + ["This is sentence 49."]
+
def test_chunk_text_with_overlap():
"""Test chunking with overlap between chunks"""
# Create text with distinct parts to test overlap
text = "Part A. Part B. Part C. Part D. Part E."
-
- assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+
+ assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+ "Part A. Part B. Part C.",
+ "Part C. Part D. Part E.",
+ "Part E.",
+ ]
def test_chunk_text_zero_overlap():
"""Test chunking with zero overlap"""
text = "Part A. Part B. Part C. Part D. Part E."
-
+
# 2 tokens = ~8 chars
- assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+ assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+ "Part A. Part B.",
+ "Part C. Part D.",
+ "Part E.",
+ ]
def test_chunk_text_clean_break():
"""Test that chunking attempts to break at sentence boundaries"""
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
-
+
max_tokens = 5 # Enough for about 2 sentences
- assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+ assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+ "First sentence. Second sentence.",
+ "Third sentence. Fourth sentence.",
+ ]
def test_chunk_text_very_long_sentences():
"""Test with very long sentences that exceed the token limit"""
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
-
+
max_tokens = 5 # Small limit to force splitting
assert list(chunk_text(text, max_tokens=max_tokens)) == [
- 'This is a very long sentence with many many',
- 'words that will definitely exceed the',
- 'token limit we set for',
- 'this particular test',
- 'case and should be split into multiple',
- 'chunks by the function.',
+ "This is a very long sentence with many many",
+ "words that will definitely exceed the",
+ "token limit we set for",
+ "this particular test",
+ "case and should be split into multiple",
+ "chunks by the function.",
]
-
-
-@pytest.mark.parametrize(
- "string, expected_count",
- [
- ("", 0),
- ("a" * CHARS_PER_TOKEN, 1),
- ("a" * (CHARS_PER_TOKEN * 2), 2),
- ("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
- ("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
- ]
-)
-def test_approx_token_count(string, expected_count):
- assert approx_token_count(string) == expected_count
diff --git a/tests/memory/common/test_tokens.py b/tests/memory/common/test_tokens.py
new file mode 100644
index 0000000..82be607
--- /dev/null
+++ b/tests/memory/common/test_tokens.py
@@ -0,0 +1,16 @@
+import pytest
+from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count
+
+
+@pytest.mark.parametrize(
+ "string, expected_count",
+ [
+ ("", 0),
+ ("a" * CHARS_PER_TOKEN, 1),
+ ("a" * (CHARS_PER_TOKEN * 2), 2),
+ ("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
+ ("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
+ ],
+)
+def test_approx_token_count(string, expected_count):
+ assert approx_token_count(string) == expected_count