mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 23:24:43 +02:00
second pass in search
This commit is contained in:
parent
06eec621c1
commit
8eb6374cac
@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
|
||||
if (isLoading) {
|
||||
return <Loading message="Searching..." />
|
||||
}
|
||||
console.log("results",results)
|
||||
return (
|
||||
<div className="search-results">
|
||||
{results.length > 0 && (
|
||||
|
65
src/memory/api/search/scorer.py
Normal file
65
src/memory/api/search/scorer.py
Normal file
@ -0,0 +1,65 @@
|
||||
import asyncio
|
||||
from bs4 import BeautifulSoup
|
||||
from PIL import Image
|
||||
|
||||
from memory.common.db.models.source_item import Chunk
|
||||
from memory.common import llms, settings, tokens
|
||||
|
||||
|
||||
SCORE_CHUNK_SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.
|
||||
|
||||
You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
|
||||
"""
|
||||
|
||||
SCORE_CHUNK_PROMPT = """
|
||||
Here is the query:
|
||||
<query>{query}</query>
|
||||
|
||||
Here is the chunk:
|
||||
<chunk>
|
||||
{chunk}
|
||||
</chunk>
|
||||
|
||||
Please return your score as a number between 0 and 1 formatted as:
|
||||
<score>your score</score>
|
||||
|
||||
Please always return a summary of any images provided.
|
||||
"""
|
||||
|
||||
|
||||
async def score_chunk(query: str, chunk: Chunk) -> Chunk:
|
||||
data = chunk.data
|
||||
chunk_text = "\n".join(text for text in data if isinstance(text, str))
|
||||
images = [image for image in data if isinstance(image, Image.Image)]
|
||||
prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
|
||||
response = await asyncio.to_thread(
|
||||
llms.call,
|
||||
prompt,
|
||||
settings.RANKER_MODEL,
|
||||
images=images,
|
||||
system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
soup = BeautifulSoup(response, "html.parser")
|
||||
if not (score := soup.find("score")):
|
||||
chunk.relevance_score = 0.0
|
||||
else:
|
||||
try:
|
||||
chunk.relevance_score = float(score.text.strip())
|
||||
except ValueError:
|
||||
chunk.relevance_score = 0.0
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
async def rank_chunks(
|
||||
query: str, chunks: list[Chunk], min_score: float = 0
|
||||
) -> list[Chunk]:
|
||||
calls = [score_chunk(query, chunk) for chunk in chunks]
|
||||
scored = await asyncio.gather(*calls)
|
||||
return sorted(
|
||||
[chunk for chunk in scored if chunk.relevance_score >= min_score],
|
||||
key=lambda x: x.relevance_score or 0,
|
||||
reverse=True,
|
||||
)
|
@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.collections import ALL_COLLECTIONS
|
||||
from memory.api.search.embeddings import search_chunks_embeddings
|
||||
from memory.api.search import scorer
|
||||
|
||||
if settings.ENABLE_BM25_SEARCH:
|
||||
from memory.api.search.bm25 import search_bm25_chunks
|
||||
@ -40,7 +41,14 @@ async def search_chunks(
|
||||
with make_session() as db:
|
||||
chunks = (
|
||||
db.query(Chunk)
|
||||
.options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
|
||||
.options(
|
||||
load_only(
|
||||
Chunk.id, # type: ignore
|
||||
Chunk.source_id, # type: ignore
|
||||
Chunk.content, # type: ignore
|
||||
Chunk.file_paths, # type: ignore
|
||||
)
|
||||
)
|
||||
.filter(Chunk.id.in_(all_ids))
|
||||
.all()
|
||||
)
|
||||
@ -91,4 +99,6 @@ async def search(
|
||||
filters,
|
||||
timeout,
|
||||
)
|
||||
if settings.ENABLE_SEARCH_SCORING:
|
||||
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
|
||||
return await search_sources(chunks, previews)
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Iterable, Any
|
||||
import re
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common import settings, tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
|
||||
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
|
||||
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
|
||||
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
Vector = list[float]
|
||||
@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
|
||||
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
|
||||
|
||||
|
||||
def approx_token_count(s: str) -> int:
|
||||
return len(s) // CHARS_PER_TOKEN
|
||||
|
||||
|
||||
def yield_word_chunks(
|
||||
text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
|
||||
) -> Iterable[str]:
|
||||
@ -36,7 +31,7 @@ def yield_word_chunks(
|
||||
current = ""
|
||||
for word in words:
|
||||
new_chunk = f"{current} {word}".strip()
|
||||
if current and approx_token_count(new_chunk) > max_tokens:
|
||||
if current and tokens.approx_token_count(new_chunk) > max_tokens:
|
||||
yield current
|
||||
current = word
|
||||
else:
|
||||
@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
|
||||
if not paragraph.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(paragraph) <= max_tokens:
|
||||
if tokens.approx_token_count(paragraph) <= max_tokens:
|
||||
yield paragraph
|
||||
continue
|
||||
|
||||
@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
|
||||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(sentence) <= max_tokens:
|
||||
if tokens.approx_token_count(sentence) <= max_tokens:
|
||||
yield sentence
|
||||
continue
|
||||
|
||||
@ -99,16 +94,16 @@ def chunk_text(
|
||||
if not text:
|
||||
return
|
||||
|
||||
if approx_token_count(text) <= max_tokens:
|
||||
if tokens.approx_token_count(text) <= max_tokens:
|
||||
yield text
|
||||
return
|
||||
|
||||
overlap_chars = overlap * CHARS_PER_TOKEN
|
||||
overlap_chars = overlap * tokens.CHARS_PER_TOKEN
|
||||
current = ""
|
||||
|
||||
for span in yield_spans(text, max_tokens):
|
||||
current = f"{current} {span}".strip()
|
||||
if approx_token_count(current) < max_tokens:
|
||||
if tokens.approx_token_count(current) < max_tokens:
|
||||
continue
|
||||
|
||||
if overlap <= 0:
|
||||
|
@ -156,9 +156,11 @@ class Chunk(Base):
|
||||
collection_name = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
vector: list[float] = []
|
||||
item_metadata: dict[str, Any] = {}
|
||||
images: list[Image.Image] = []
|
||||
relevance_score: float = 0.0
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
|
122
src/memory/common/llms.py
Normal file
122
src/memory/common/llms.py
Normal file
@ -0,0 +1,122 @@
|
||||
import logging
|
||||
import base64
|
||||
import io
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings, tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that creates concise summaries and identifies key topics.
|
||||
"""
|
||||
|
||||
|
||||
def encode_image(image: Image.Image) -> str:
|
||||
"""Encode PIL Image to base64 string."""
|
||||
buffer = io.BytesIO()
|
||||
# Convert to RGB if necessary (for RGBA, etc.)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format="JPEG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def call_openai(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call OpenAI API for summarization."""
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
try:
|
||||
user_content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
user_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call_anthropic(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
"""Call Anthropic API for summarization."""
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
try:
|
||||
# Prepare the message content
|
||||
content: Any = [{"type": "text", "text": prompt}]
|
||||
if images:
|
||||
# Add images if provided
|
||||
for image in images:
|
||||
encoded_image = encode_image(image)
|
||||
content.append(
|
||||
{ # type: ignore
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": encoded_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
model=model.split("/")[1],
|
||||
messages=[{"role": "user", "content": content}], # type: ignore
|
||||
system=system_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return response.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def call(
|
||||
prompt: str,
|
||||
model: str,
|
||||
images: list[Image.Image] = [],
|
||||
system_prompt: str = SYSTEM_PROMPT,
|
||||
) -> str:
|
||||
if model.startswith("anthropic"):
|
||||
return call_anthropic(prompt, model, images, system_prompt)
|
||||
return call_openai(prompt, model, images, system_prompt)
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * tokens.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
|
||||
else:
|
||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
|
||||
|
||||
# Search settings
|
||||
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||
ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
|
||||
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
|
||||
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from memory.common import settings, chunker
|
||||
from memory.common import settings, tokens, llms
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
|
||||
return {"summary": summary, "tags": tags}
|
||||
|
||||
|
||||
def _call_openai(prompt: str) -> dict[str, Any]:
|
||||
"""Call OpenAI API for summarization."""
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=settings.SUMMARIZER_MODEL.split("/")[1],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return parse_response(response.choices[0].message.content or "")
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _call_anthropic(prompt: str) -> dict[str, Any]:
|
||||
"""Call Anthropic API for summarization."""
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
try:
|
||||
response = client.messages.create(
|
||||
model=settings.SUMMARIZER_MODEL.split("/")[1],
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
return parse_response(response.content[0].text)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
logger.error(response.content[0].text)
|
||||
raise
|
||||
|
||||
|
||||
def truncate(content: str, target_tokens: int) -> str:
|
||||
target_chars = target_tokens * chunker.CHARS_PER_TOKEN
|
||||
if len(content) > target_chars:
|
||||
return content[:target_chars].rsplit(" ", 1)[0] + "..."
|
||||
return content
|
||||
|
||||
|
||||
def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Summarize content to approximately target_tokens length and generate tags.
|
||||
@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
||||
summary, tags = content, []
|
||||
|
||||
# If content is already short enough, just extract tags
|
||||
current_tokens = chunker.approx_token_count(content)
|
||||
current_tokens = tokens.approx_token_count(content)
|
||||
if current_tokens <= target_tokens:
|
||||
logger.info(
|
||||
f"Content already under {target_tokens} tokens, extracting tags only"
|
||||
@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
||||
else:
|
||||
prompt = SUMMARY_PROMPT.format(
|
||||
target_tokens=target_tokens,
|
||||
target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
|
||||
target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
|
||||
content=content,
|
||||
)
|
||||
|
||||
if chunker.approx_token_count(prompt) > MAX_TOKENS:
|
||||
if tokens.approx_token_count(prompt) > MAX_TOKENS:
|
||||
logger.warning(
|
||||
f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
|
||||
f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
|
||||
)
|
||||
prompt = truncate(prompt, MAX_TOKENS - 20)
|
||||
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
||||
|
||||
try:
|
||||
if settings.SUMMARIZER_MODEL.startswith("anthropic"):
|
||||
result = _call_anthropic(prompt)
|
||||
else:
|
||||
result = _call_openai(prompt)
|
||||
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
|
||||
result = parse_response(response)
|
||||
|
||||
summary = result.get("summary", "")
|
||||
tags = result.get("tags", [])
|
||||
@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
||||
traceback.print_exc()
|
||||
logger.error(f"Summarization failed: {e}")
|
||||
|
||||
tokens = chunker.approx_token_count(summary)
|
||||
if tokens > target_tokens * 1.5:
|
||||
logger.warning(f"Summary too long ({tokens} tokens), truncating")
|
||||
summary = truncate(content, target_tokens)
|
||||
summary_tokens = tokens.approx_token_count(summary)
|
||||
if summary_tokens > target_tokens * 1.5:
|
||||
logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
|
||||
summary = llms.truncate(content, target_tokens)
|
||||
|
||||
return summary, tags
|
||||
|
128
src/memory/common/tokens.py
Normal file
128
src/memory/common/tokens.py
Normal file
@ -0,0 +1,128 @@
|
||||
import logging
|
||||
from PIL import Image
|
||||
import math
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def approx_token_count(s: str) -> int:
|
||||
return len(s) // CHARS_PER_TOKEN
|
||||
|
||||
|
||||
def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
|
||||
"""
|
||||
Estimate tokens for an image using OpenAI's counting method.
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
detail: "high" or "low" detail level
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if detail == "low":
|
||||
return 85
|
||||
|
||||
# For high detail, OpenAI resizes the image to fit within 2048x2048
|
||||
# while maintaining aspect ratio, then counts 512x512 tiles
|
||||
width, height = image.size
|
||||
|
||||
# Resize logic to fit within 2048x2048
|
||||
if width > 2048 or height > 2048:
|
||||
if width > height:
|
||||
height = int(height * 2048 / width)
|
||||
width = 2048
|
||||
else:
|
||||
width = int(width * 2048 / height)
|
||||
height = 2048
|
||||
|
||||
# Further resize so shortest side is 768px
|
||||
if width < height:
|
||||
if width > 768:
|
||||
height = int(height * 768 / width)
|
||||
width = 768
|
||||
else:
|
||||
if height > 768:
|
||||
width = int(width * 768 / height)
|
||||
height = 768
|
||||
|
||||
# Count 512x512 tiles
|
||||
tiles_width = math.ceil(width / 512)
|
||||
tiles_height = math.ceil(height / 512)
|
||||
total_tiles = tiles_width * tiles_height
|
||||
|
||||
# Each tile costs 170 tokens, plus 85 base tokens
|
||||
return total_tiles * 170 + 85
|
||||
|
||||
|
||||
def estimate_anthropic_image_tokens(image: Image.Image) -> int:
|
||||
"""
|
||||
Estimate tokens for an image using Anthropic's counting method.
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
width, height = image.size
|
||||
|
||||
# Anthropic's token counting is based on image dimensions
|
||||
# They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
|
||||
# But they also have a base cost per image
|
||||
|
||||
# Rough approximation based on Anthropic's documentation
|
||||
# They count tokens based on the image size after potential resizing
|
||||
total_pixels = width * height
|
||||
|
||||
# Anthropic typically charges around 1.15 tokens per 1000 pixels
|
||||
# with a minimum base cost
|
||||
base_tokens = 100 # Base cost for any image
|
||||
pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)
|
||||
|
||||
return base_tokens + pixel_tokens
|
||||
|
||||
|
||||
def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
|
||||
"""
|
||||
Estimate tokens for an image based on the model provider.
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
|
||||
detail: Detail level for OpenAI models ("high" or "low")
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if model.startswith("anthropic"):
|
||||
return estimate_anthropic_image_tokens(image)
|
||||
else:
|
||||
return estimate_openai_image_tokens(image, detail)
|
||||
|
||||
|
||||
def estimate_total_tokens(
|
||||
prompt: str, images: list[Image.Image], model: str, detail: str = "high"
|
||||
) -> int:
|
||||
"""
|
||||
Estimate total tokens for a prompt with images.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
images: List of PIL Images
|
||||
model: Model string
|
||||
detail: Detail level for OpenAI models
|
||||
|
||||
Returns:
|
||||
Estimated total token count
|
||||
"""
|
||||
# Estimate text tokens
|
||||
text_tokens = approx_token_count(prompt)
|
||||
|
||||
# Estimate image tokens
|
||||
image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)
|
||||
|
||||
return text_tokens + image_tokens
|
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
|
||||
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
|
||||
from memory.common.tokens import CHARS_PER_TOKEN
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
|
||||
(" ", []), # Just spaces
|
||||
("\n\t ", []), # Whitespace characters
|
||||
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_yield_word_chunk_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_word_chunks with various inputs"""
|
||||
@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
|
||||
[
|
||||
(
|
||||
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
|
||||
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
|
||||
["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
|
||||
),
|
||||
(
|
||||
"supercalifragilisticexpialidocious",
|
||||
["supercalifragilisticexpialidocious"],
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_yield_word_chunk_long_text(text, expected):
|
||||
"""Test chunking with long text that exceeds token limits"""
|
||||
@ -41,7 +42,7 @@ def test_yield_word_chunk_single_long_word():
|
||||
"""Test behavior with a single word longer than the token limit"""
|
||||
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
|
||||
|
||||
|
||||
chunks = list(yield_word_chunks(long_word, max_tokens))
|
||||
# With our changes, this should be a single chunk
|
||||
assert len(chunks) == 1
|
||||
@ -52,8 +53,13 @@ def test_yield_word_chunk_small_token_limit():
|
||||
"""Test with a very small max_tokens value to force chunking"""
|
||||
text = "one two three four five"
|
||||
max_tokens = 1 # Very small to force chunking after each word
|
||||
|
||||
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
|
||||
|
||||
assert list(yield_word_chunks(text, max_tokens)) == [
|
||||
"one two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
|
||||
(
|
||||
"word1 word2", # 11 chars with space
|
||||
3, # 12 chars limit
|
||||
["word1 word2"]
|
||||
["word1 word2"],
|
||||
),
|
||||
# Text just over token limit should split
|
||||
(
|
||||
"word1 word2 word3", # 17 chars with spaces
|
||||
4, # 16 chars limit
|
||||
["word1 word2 word3"]
|
||||
["word1 word2 word3"],
|
||||
),
|
||||
# Each word exactly at token limit
|
||||
(
|
||||
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
|
||||
1, # 1 token limit (4 chars)
|
||||
["aaaa", "bbbb", "cccc"]
|
||||
["aaaa", "bbbb", "cccc"],
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
|
||||
"""Test different combinations of text and token limits"""
|
||||
@ -95,15 +101,15 @@ def test_yield_word_chunk_real_world_example():
|
||||
"It tries to maximize chunk size while staying under the specified token limit. "
|
||||
"This behavior is essential for processing large documents efficiently."
|
||||
)
|
||||
|
||||
|
||||
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
|
||||
assert list(yield_word_chunks(text, max_tokens)) == [
|
||||
'The yield_word_chunks function splits text',
|
||||
'into chunks based on word boundaries. It',
|
||||
'tries to maximize chunk size while staying',
|
||||
'under the specified token limit. This',
|
||||
'behavior is essential for processing large',
|
||||
'documents efficiently.',
|
||||
"The yield_word_chunks function splits text",
|
||||
"into chunks based on word boundaries. It",
|
||||
"tries to maximize chunk size while staying",
|
||||
"under the specified token limit. This",
|
||||
"behavior is essential for processing large",
|
||||
"documents efficiently.",
|
||||
]
|
||||
|
||||
|
||||
@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
|
||||
"text, expected",
|
||||
[
|
||||
("", []), # Empty text should yield nothing
|
||||
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
|
||||
(
|
||||
"Simple paragraph",
|
||||
["Simple paragraph"],
|
||||
), # Single paragraph under token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_yield_spans_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_spans with various inputs"""
|
||||
@ -135,7 +144,7 @@ def test_yield_spans_sentences():
|
||||
sentence1 = "Short sentence one." # ~20 chars
|
||||
sentence2 = "Another short sentence." # ~24 chars
|
||||
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
|
||||
|
||||
|
||||
# Function should now preserve punctuation
|
||||
expected = ["Short sentence one.", "Another short sentence."]
|
||||
assert list(yield_spans(text, max_tokens)) == expected
|
||||
@ -145,8 +154,14 @@ def test_yield_spans_words():
|
||||
"""Test splitting by words when sentences exceed token limit"""
|
||||
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
|
||||
long_sentence = "This sentence has several words and needs word-level chunking."
|
||||
|
||||
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
|
||||
|
||||
assert list(yield_spans(long_sentence, max_tokens)) == [
|
||||
"This sentence",
|
||||
"has several",
|
||||
"words and needs",
|
||||
"word-level",
|
||||
"chunking.",
|
||||
]
|
||||
|
||||
|
||||
def test_yield_spans_complex_document():
|
||||
@ -158,7 +173,7 @@ def test_yield_spans_complex_document():
|
||||
"This one is longer and might need word splitting depending on the limit.\n\n"
|
||||
"Final short paragraph."
|
||||
)
|
||||
|
||||
|
||||
assert list(yield_spans(text, max_tokens)) == [
|
||||
"Paragraph one with a short sentence.",
|
||||
"And another sentence that should be split.",
|
||||
@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
|
||||
"Some are short.",
|
||||
"This one is longer and might need word",
|
||||
"splitting depending on the limit.",
|
||||
"Final short paragraph."
|
||||
"Final short paragraph.",
|
||||
]
|
||||
|
||||
|
||||
@ -175,21 +190,25 @@ def test_yield_spans_very_long_word():
|
||||
"""Test with a word that exceeds the token limit"""
|
||||
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
|
||||
|
||||
|
||||
assert list(yield_spans(long_word, max_tokens)) == [long_word]
|
||||
|
||||
|
||||
def test_yield_spans_with_punctuation():
|
||||
"""Test sentence splitting with various punctuation"""
|
||||
text = "First sentence! Second sentence? Third sentence."
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == [
|
||||
"First sentence!",
|
||||
"Second sentence?",
|
||||
"Third sentence.",
|
||||
]
|
||||
|
||||
|
||||
def test_yield_spans_edge_cases():
|
||||
"""Test edge cases like empty paragraphs, single character paragraphs"""
|
||||
text = "\n\nA\n\n\n\nB\n\n"
|
||||
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
|
||||
|
||||
|
||||
@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
|
||||
("", []), # Empty text
|
||||
("Short text", ["Short text"]), # Text below token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_chunk_text_basic_behavior(text, expected):
|
||||
"""Test basic behavior of chunk_text with various inputs"""
|
||||
@ -223,63 +242,58 @@ def test_chunk_text_long_text():
|
||||
# Create a long text that will need multiple chunks
|
||||
sentences = [f"This is sentence {i:02}." for i in range(50)]
|
||||
text = " ".join(sentences)
|
||||
|
||||
|
||||
max_tokens = 10 # 10 tokens = ~40 chars
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
|
||||
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
|
||||
] + [
|
||||
'This is sentence 49.'
|
||||
]
|
||||
|
||||
f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
|
||||
] + ["This is sentence 49."]
|
||||
|
||||
|
||||
def test_chunk_text_with_overlap():
|
||||
"""Test chunking with overlap between chunks"""
|
||||
# Create text with distinct parts to test overlap
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
|
||||
|
||||
assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
|
||||
"Part A. Part B. Part C.",
|
||||
"Part C. Part D. Part E.",
|
||||
"Part E.",
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_zero_overlap():
|
||||
"""Test chunking with zero overlap"""
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
|
||||
# 2 tokens = ~8 chars
|
||||
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
|
||||
assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
|
||||
"Part A. Part B.",
|
||||
"Part C. Part D.",
|
||||
"Part E.",
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_clean_break():
|
||||
"""Test that chunking attempts to break at sentence boundaries"""
|
||||
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||
|
||||
|
||||
max_tokens = 5 # Enough for about 2 sentences
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
|
||||
"First sentence. Second sentence.",
|
||||
"Third sentence. Fourth sentence.",
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_very_long_sentences():
|
||||
"""Test with very long sentences that exceed the token limit"""
|
||||
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
|
||||
|
||||
|
||||
max_tokens = 5 # Small limit to force splitting
|
||||
assert list(chunk_text(text, max_tokens=max_tokens)) == [
|
||||
'This is a very long sentence with many many',
|
||||
'words that will definitely exceed the',
|
||||
'token limit we set for',
|
||||
'this particular test',
|
||||
'case and should be split into multiple',
|
||||
'chunks by the function.',
|
||||
"This is a very long sentence with many many",
|
||||
"words that will definitely exceed the",
|
||||
"token limit we set for",
|
||||
"this particular test",
|
||||
"case and should be split into multiple",
|
||||
"chunks by the function.",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"string, expected_count",
|
||||
[
|
||||
("", 0),
|
||||
("a" * CHARS_PER_TOKEN, 1),
|
||||
("a" * (CHARS_PER_TOKEN * 2), 2),
|
||||
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
|
||||
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
|
||||
]
|
||||
)
|
||||
def test_approx_token_count(string, expected_count):
|
||||
assert approx_token_count(string) == expected_count
|
||||
|
16
tests/memory/common/test_tokens.py
Normal file
16
tests/memory/common/test_tokens.py
Normal file
@ -0,0 +1,16 @@
|
||||
import pytest
|
||||
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"string, expected_count",
|
||||
[
|
||||
("", 0),
|
||||
("a" * CHARS_PER_TOKEN, 1),
|
||||
("a" * (CHARS_PER_TOKEN * 2), 2),
|
||||
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
|
||||
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
|
||||
],
|
||||
)
|
||||
def test_approx_token_count(string, expected_count):
|
||||
assert approx_token_count(string) == expected_count
|
Loading…
x
Reference in New Issue
Block a user