mirror of https://github.com/mruwnik/memory.git (synced 2025-06-29 07:34:43 +02:00)

commit 8eb6374cac, parent 06eec621c1: second pass in search

@@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
   if (isLoading) {
     return <Loading message="Searching..." />
   }
+  console.log("results",results)
   return (
     <div className="search-results">
       {results.length > 0 && (

src/memory/api/search/scorer.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import asyncio
from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens


SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    data = chunk.data
    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
    response = await asyncio.to_thread(
        llms.call,
        prompt,
        settings.RANKER_MODEL,
        images=images,
        system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
    )

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0

    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )

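The scoring pass fans out one LLM call per chunk and keeps only results above a threshold. A minimal driver sketch (hypothetical code, not part of the commit; assumes candidate_chunks came from the first search pass):

import asyncio

from memory.api.search import scorer


async def second_pass(query: str, candidate_chunks):
    # rank_chunks scores every chunk concurrently (asyncio.gather), drops
    # anything rated below min_score, and returns the rest sorted best-first.
    # 0.3 mirrors the threshold search() uses when ENABLE_SEARCH_SCORING is on.
    return await scorer.rank_chunks(query, candidate_chunks, min_score=0.3)


# e.g.: ranked = asyncio.run(second_pass("query text", chunks_from_first_pass))
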
@@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 from memory.common.collections import ALL_COLLECTIONS
 from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer

 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
@@ -40,7 +41,14 @@ async def search_chunks(
     with make_session() as db:
         chunks = (
             db.query(Chunk)
-            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .options(
+                load_only(
+                    Chunk.id,  # type: ignore
+                    Chunk.source_id,  # type: ignore
+                    Chunk.content,  # type: ignore
+                    Chunk.file_paths,  # type: ignore
+                )
+            )
             .filter(Chunk.id.in_(all_ids))
             .all()
         )
@@ -91,4 +99,6 @@ async def search(
         filters,
         timeout,
     )
+    if settings.ENABLE_SEARCH_SCORING:
+        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
     return await search_sources(chunks, previews)

@@ -2,7 +2,7 @@ import logging
 from typing import Iterable, Any
 import re

-from memory.common import settings
+from memory.common import settings, tokens

 logger = logging.getLogger(__name__)
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
 EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
 DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
 OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4


 Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 _SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


-def approx_token_count(s: str) -> int:
-    return len(s) // CHARS_PER_TOKEN
-
-
 def yield_word_chunks(
     text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
 ) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
-        if current and approx_token_count(new_chunk) > max_tokens:
+        if current and tokens.approx_token_count(new_chunk) > max_tokens:
             yield current
             current = word
         else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
         if not paragraph.strip():
             continue

-        if approx_token_count(paragraph) <= max_tokens:
+        if tokens.approx_token_count(paragraph) <= max_tokens:
            yield paragraph
            continue

@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
             if not sentence.strip():
                 continue

-            if approx_token_count(sentence) <= max_tokens:
+            if tokens.approx_token_count(sentence) <= max_tokens:
                yield sentence
                continue

@@ -99,16 +94,16 @@ def chunk_text(
     if not text:
         return

-    if approx_token_count(text) <= max_tokens:
+    if tokens.approx_token_count(text) <= max_tokens:
         yield text
         return

-    overlap_chars = overlap * CHARS_PER_TOKEN
+    overlap_chars = overlap * tokens.CHARS_PER_TOKEN
     current = ""

     for span in yield_spans(text, max_tokens):
         current = f"{current} {span}".strip()
-        if approx_token_count(current) < max_tokens:
+        if tokens.approx_token_count(current) < max_tokens:
             continue

         if overlap <= 0:

@@ -156,9 +156,11 @@ class Chunk(Base):
     collection_name = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     checked_at = Column(DateTime(timezone=True), server_default=func.now())

     vector: list[float] = []
     item_metadata: dict[str, Any] = {}
     images: list[Image.Image] = []
+    relevance_score: float = 0.0

     # One of file_path or content must be populated
     __table_args__ = (

src/memory/common/llms.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import logging
import base64
import io
from typing import Any
from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content

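The call() entry point dispatches on the "provider/model" prefix, so the scorer and summarizer never touch the OpenAI or Anthropic SDKs directly. An illustrative call (model strings are examples, not pinned by this commit):

from PIL import Image

from memory.common import llms

# "anthropic/..." routes to call_anthropic; anything else goes to call_openai.
summary = llms.call("Summarize: the cat sat on the mat.", "anthropic/claude-3-haiku-20240307")

# Images ride along as base64 JPEG parts, which is how scorer.score_chunk uses it:
img = Image.new("RGB", (64, 64))
answer = llms.call("Is this image relevant to 'cats'?", "openai/gpt-4o", images=[img])
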
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
 MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))

 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))

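Both new knobs follow the existing environment-variable pattern, so the second pass can be retargeted or disabled without code changes; for example (illustrative values, not from this commit):

RANKER_MODEL=openai/gpt-4o-mini      # any "provider/model" string llms.call accepts
ENABLE_SEARCH_SCORING=false          # boolean_env flag: skip the LLM re-ranking pass
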
@@ -4,7 +4,7 @@ from typing import Any

 from bs4 import BeautifulSoup

-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms

 logger = logging.getLogger(__name__)
@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
     return {"summary": summary, "tags": tags}


-def _call_openai(prompt: str) -> dict[str, Any]:
-    """Call OpenAI API for summarization."""
-    import openai
-
-    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
-    try:
-        response = client.chat.completions.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
-                },
-                {"role": "user", "content": prompt},
-            ],
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.choices[0].message.content or "")
-    except Exception as e:
-        logger.error(f"OpenAI API error: {e}")
-        raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
-    """Call Anthropic API for summarization."""
-    import anthropic
-
-    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
-    try:
-        response = client.messages.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[{"role": "user", "content": prompt}],
-            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.content[0].text)
-    except Exception as e:
-        logger.error(f"Anthropic API error: {e}")
-        logger.error(response.content[0].text)
-        raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
-    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
-    if len(content) > target_chars:
-        return content[:target_chars].rsplit(" ", 1)[0] + "..."
-    return content
-
-
 def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     """
     Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     summary, tags = content, []

     # If content is already short enough, just extract tags
-    current_tokens = chunker.approx_token_count(content)
+    current_tokens = tokens.approx_token_count(content)
     if current_tokens <= target_tokens:
         logger.info(
             f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     else:
         prompt = SUMMARY_PROMPT.format(
             target_tokens=target_tokens,
-            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+            target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
             content=content,
         )

-        if chunker.approx_token_count(prompt) > MAX_TOKENS:
+        if tokens.approx_token_count(prompt) > MAX_TOKENS:
             logger.warning(
-                f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
+                f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
             )
-            prompt = truncate(prompt, MAX_TOKENS - 20)
+            prompt = llms.truncate(prompt, MAX_TOKENS - 20)

     try:
-        if settings.SUMMARIZER_MODEL.startswith("anthropic"):
-            result = _call_anthropic(prompt)
-        else:
-            result = _call_openai(prompt)
+        response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+        result = parse_response(response)

         summary = result.get("summary", "")
         tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
         traceback.print_exc()
         logger.error(f"Summarization failed: {e}")

-    tokens = chunker.approx_token_count(summary)
-    if tokens > target_tokens * 1.5:
-        logger.warning(f"Summary too long ({tokens} tokens), truncating")
-        summary = truncate(content, target_tokens)
+    summary_tokens = tokens.approx_token_count(summary)
+    if summary_tokens > target_tokens * 1.5:
+        logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
+        summary = llms.truncate(content, target_tokens)

     return summary, tags

src/memory/common/tokens.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)


CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85


def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens


def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens

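The tile arithmetic is easiest to sanity-check with a worked example (hypothetical image, using the heuristics above; actual provider billing may differ):

from PIL import Image

from memory.common import tokens

img = Image.new("RGB", (2048, 1536))

# OpenAI, high detail: already fits in 2048x2048; shortest side 1536 -> scaled
# to 1024x768; ceil(1024/512) * ceil(768/512) = 4 tiles -> 4 * 170 + 85 = 765.
assert tokens.estimate_openai_image_tokens(img) == 765

# Anthropic heuristic: 2048 * 1536 = 3,145,728 px -> ceil(3145.728 * 1.15) + 100 = 3718.
assert tokens.estimate_anthropic_image_tokens(img) == 3718

# Combined estimate: len("hello world") // 4 = 2 text tokens, plus the image.
assert tokens.estimate_total_tokens("hello world", [img], "openai/gpt-4o") == 2 + 765
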
@@ -1,5 +1,6 @@
 import pytest
-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN


 @pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
         (" ", []),  # Just spaces
         ("\n\t ", []),  # Whitespace characters
         ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
-    ]
+    ],
 )
 def test_yield_word_chunk_basic_behavior(text, expected):
     """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
     [
         (
             "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
-            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
         ),
         (
             "supercalifragilisticexpialidocious",
             ["supercalifragilisticexpialidocious"],
-        )
-    ]
+        ),
+    ],
 )
 def test_yield_word_chunk_long_text(text, expected):
     """Test chunking with long text that exceeds token limits"""
@@ -53,7 +54,12 @@ def test_yield_word_chunk_small_token_limit():
     text = "one two three four five"
     max_tokens = 1  # Very small to force chunking after each word

-    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+    assert list(yield_word_chunks(text, max_tokens)) == [
+        "one two",
+        "three",
+        "four",
+        "five",
+    ]


 @pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
         (
             "word1 word2",  # 11 chars with space
             3,  # 12 chars limit
-            ["word1 word2"]
+            ["word1 word2"],
         ),
         # Text just over token limit should split
         (
             "word1 word2 word3",  # 17 chars with spaces
             4,  # 16 chars limit
-            ["word1 word2 word3"]
+            ["word1 word2 word3"],
         ),
         # Each word exactly at token limit
         (
             "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
             1,  # 1 token limit (4 chars)
-            ["aaaa", "bbbb", "cccc"]
+            ["aaaa", "bbbb", "cccc"],
         ),
-    ]
+    ],
 )
 def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
     """Test different combinations of text and token limits"""
@@ -98,12 +104,12 @@ def test_yield_word_chunk_real_world_example():

     max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
     assert list(yield_word_chunks(text, max_tokens)) == [
-        'The yield_word_chunks function splits text',
-        'into chunks based on word boundaries. It',
-        'tries to maximize chunk size while staying',
-        'under the specified token limit. This',
-        'behavior is essential for processing large',
-        'documents efficiently.',
+        "The yield_word_chunks function splits text",
+        "into chunks based on word boundaries. It",
+        "tries to maximize chunk size while staying",
+        "under the specified token limit. This",
+        "behavior is essential for processing large",
+        "documents efficiently.",
     ]

@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
     "text, expected",
     [
         ("", []),  # Empty text should yield nothing
-        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
+        (
+            "Simple paragraph",
+            ["Simple paragraph"],
+        ),  # Single paragraph under token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_yield_spans_basic_behavior(text, expected):
     """Test basic behavior of yield_spans with various inputs"""
@@ -146,7 +155,13 @@ def test_yield_spans_words():
     max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
     long_sentence = "This sentence has several words and needs word-level chunking."

-    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+    assert list(yield_spans(long_sentence, max_tokens)) == [
+        "This sentence",
+        "has several",
+        "words and needs",
+        "word-level",
+        "chunking.",
+    ]


 def test_yield_spans_complex_document():
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
         "Some are short.",
         "This one is longer and might need word",
         "splitting depending on the limit.",
-        "Final short paragraph."
+        "Final short paragraph.",
     ]

@@ -183,7 +198,11 @@ def test_yield_spans_with_punctuation():
     """Test sentence splitting with various punctuation"""
     text = "First sentence! Second sentence? Third sentence."

-    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+    assert list(yield_spans(text, max_tokens=10)) == [
+        "First sentence!",
+        "Second sentence?",
+        "Third sentence.",
+    ]


 def test_yield_spans_edge_cases():
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
         ("", []),  # Empty text
         ("Short text", ["Short text"]),  # Text below token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_chunk_text_basic_behavior(text, expected):
     """Test basic behavior of chunk_text with various inputs"""
@@ -226,10 +245,8 @@ def test_chunk_text_long_text():

     max_tokens = 10  # 10 tokens = ~40 chars
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
-        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
-    ] + [
-        'This is sentence 49.'
-    ]
+        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+    ] + ["This is sentence 49."]


 def test_chunk_text_with_overlap():
@@ -237,7 +254,11 @@ def test_chunk_text_with_overlap():
     # Create text with distinct parts to test overlap
     text = "Part A. Part B. Part C. Part D. Part E."

-    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+        "Part A. Part B. Part C.",
+        "Part C. Part D. Part E.",
+        "Part E.",
+    ]


 def test_chunk_text_zero_overlap():
@@ -245,7 +266,11 @@ def test_chunk_text_zero_overlap():
     text = "Part A. Part B. Part C. Part D. Part E."

     # 2 tokens = ~8 chars
-    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+        "Part A. Part B.",
+        "Part C. Part D.",
+        "Part E.",
+    ]


 def test_chunk_text_clean_break():
@@ -253,7 +278,10 @@ def test_chunk_text_clean_break():
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."

     max_tokens = 5  # Enough for about 2 sentences
-    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+        "First sentence. Second sentence.",
+        "Third sentence. Fourth sentence.",
+    ]


 def test_chunk_text_very_long_sentences():
@@ -262,24 +290,10 @@ def test_chunk_text_very_long_sentences():

     max_tokens = 5  # Small limit to force splitting
     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        'This is a very long sentence with many many',
-        'words that will definitely exceed the',
-        'token limit we set for',
-        'this particular test',
-        'case and should be split into multiple',
-        'chunks by the function.',
+        "This is a very long sentence with many many",
+        "words that will definitely exceed the",
+        "token limit we set for",
+        "this particular test",
+        "case and should be split into multiple",
+        "chunks by the function.",
     ]
-
-
-@pytest.mark.parametrize(
-    "string, expected_count",
-    [
-        ("", 0),
-        ("a" * CHARS_PER_TOKEN, 1),
-        ("a" * (CHARS_PER_TOKEN * 2), 2),
-        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
-        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
-    ]
-)
-def test_approx_token_count(string, expected_count):
-    assert approx_token_count(string) == expected_count

tests/memory/common/test_tokens.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count