Mirror of https://github.com/mruwnik/memory.git, synced 2025-07-29 14:16:09 +02:00
Compare commits
5 Commits
06eec621c1 ... 387bd962e6
Author | SHA1 | Date
---|---|---
 | 387bd962e6 |
 | 1276b83ffb |
 | 510cfdf82f |
 | 96c2f22b16 |
 | 8eb6374cac |
@@ -156,7 +156,9 @@ services:
      STATIC_DIR: "/app/static"
      VOYAGE_API_KEY: ${VOYAGE_API_KEY}
      ENABLE_BM25_SEARCH: false
    secrets: [postgres_password]
      OPENAI_API_KEY_FILE: /run/secrets/openai_key
      ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
    secrets: [postgres_password, openai_key, anthropic_key]
    volumes:
      - ./memory_files:/app/memory_files:rw
    healthcheck:
@@ -40,8 +40,7 @@ const Search = () => {
    setIsLoading(true)
    try {
      console.log(params)
      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
      setResults(searchResults || [])
    } catch (error) {
      console.error('Search error:', error)
@@ -10,12 +10,17 @@ type Filter = {
  [key: string]: any
}

type SearchConfig = {
  previews: boolean
  useScores: boolean
  limit: number
}

export interface SearchParams {
  query: string
  previews: boolean
  modalities: string[]
  filters: Filter
  limit: number
  config: SearchConfig
}

interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
  const [query, setQuery] = useState('')
  const [previews, setPreviews] = useState(false)
  const [useScores, setUseScores] = useState(false)
  const [modalities, setModalities] = useState<Record<string, boolean>>({})
  const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
  const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
    onSearch({
      query,
      previews,
      modalities: getSelectedItems(modalities),
      config: {
        previews,
        useScores,
        limit
      },
      filters: {
        tags: getSelectedItems(tags),
        ...cleanFilters(dynamicFilters)
      },
      limit
    })
  }
@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
          Include content previews
        </label>
      </div>
      <div className="search-option">
        <label>
          <input
            type="checkbox"
            checked={useScores}
            onChange={(e) => setUseScores(e.target.checked)}
          />
          Score results with a LLM before returning
        </label>
      </div>
      <SelectableTags
        title="Modalities"
@@ -1,6 +1,4 @@
import { useState, useEffect } from 'react'
import ReactMarkdown from 'react-markdown'
import { useMCP } from '@/hooks/useMCP'
import { SERVER_URL } from '@/hooks/useAuth'

export type SearchItem = {
@@ -74,25 +72,14 @@ export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: Se
export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
  const title = metadata?.title || filename || 'Untitled'
  const { fetchFile } = useMCP()
  const [mime_type, setMimeType] = useState<string>()
  const [content, setContent] = useState<string>()
  useEffect(() => {
    const fetchImage = async () => {
      const files = await fetchFile(filename)
      const {mime_type, content} = files[0]
      setMimeType(mime_type)
      setContent(content)
    }
    fetchImage()
  }, [filename])

  return (
    <div className="search-result-card">
      <h4>{title}</h4>
      <Tag tags={tags} />
      <Metadata metadata={metadata} />
      <div className="image-container">
        {mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
        <img src={`${SERVER_URL}/files/${filename}`} alt={title} className="search-result-image"/>
      </div>
    </div>
  )
@@ -151,13 +151,12 @@ export const useMCP = () => {
    return (await mcpCall('get_metadata_schemas'))[0]
  }, [mcpCall])

  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
    return await mcpCall('search_knowledge_base', {
      query,
      filters,
      config,
      modalities,
      previews,
      limit,
    })
  }, [mcpCall])
@@ -6,17 +6,16 @@ import base64
import logging
import pathlib
from datetime import datetime, timezone
from typing import Any
from PIL import Image

from pydantic import BaseModel
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from mcp.server.fastmcp.resources.base import Resource

from memory.api.MCP.tools import mcp
from memory.api.search.search import SearchFilters, search
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
@mcp.tool()
async def search_knowledge_base(
    query: str,
    filters: dict[str, Any],
    filters: SearchFilters,
    config: SearchConfig = SearchConfig(),
    modalities: set[str] = set(),
    previews: bool = False,
    limit: int = 10,
) -> list[dict]:
    """
    Search user's stored content including emails, documents, articles, books.
    Use to find specific information the user has saved or received.
    Combine with search_observations for complete user context.
    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.

    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.

    Args:
        query: Natural language search query - be descriptive about what you're looking for
        previews: Include actual content in results - when false only a snippet is returned
        modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
        filters: Filter by tags, source_ids, etc.
        limit: Max results (1-100)
        filters: a dictionary with the following keys:
            - tags: a list of tags to filter by
            - source_ids: a list of source ids to filter by
            - min_size: the minimum size of the content to filter by
            - max_size: the maximum size of the content to filter by
            - min_created_at: the minimum created_at date to filter by
            - max_created_at: the maximum created_at date to filter by
        config: a dictionary with the following keys:
            - limit: the maximum number of results to return
            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
            - useScores: whether to score the results with a LLM before returning - this results in better results but is slower

    Returns: List of search results with id, score, chunks, content, filename
    Higher scores (>0.7) indicate strong matches.
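Usage sketch for the reworked tool (illustrative only, not part of this change; the tag, date, and limit values below are invented, and only the keys documented in the docstring above are assumed):

```python
# Hypothetical call against the new signature: filters and config are plain
# dicts shaped like SearchFilters / SearchConfig.
results = await search_knowledge_base(
    query="notes about vector databases",
    modalities={"blog", "book"},
    filters={"tags": ["databases"], "min_created_at": "2024-01-01"},
    config={"limit": 5, "previews": True, "useScores": True},
)
for result in results:
    print(result["id"], result.get("filename"))
```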
@@ -112,10 +121,9 @@ async def search_knowledge_base(
    upload_data = extract.extract_text(query, skip_summary=True)
    results = await search(
        upload_data,
        previews=previews,
        modalities=modalities,
        limit=limit,
        filters=search_filters,
        config=config,
    )

    return [result.model_dump() for result in results]
@@ -211,7 +219,7 @@ async def search_observations(
    tags: list[str] | None = None,
    observation_types: list[str] | None = None,
    min_confidences: dict[str, float] = {},
    limit: int = 10,
    config: SearchConfig = SearchConfig(),
) -> list[dict]:
    """
    Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
        tags: Filter by tags (must have at least one matching tag)
        observation_types: Filter by: belief, preference, behavior, contradiction, general
        min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
        limit: Max results (1-100)
        config: SearchConfig

    Returns: List with content, tags, created_at, metadata
    Results sorted by relevance to your query.
@@ -246,9 +254,7 @@ async def search_observations(
            extract.DataChunk(data=[semantic_text]),
            extract.DataChunk(data=[temporal]),
        ],
        previews=True,
        modalities={"semantic", "temporal"},
        limit=limit,
        filters=SearchFilters(
            subject=subject,
            min_confidences=min_confidences,
@@ -256,7 +262,7 @@ async def search_observations(
            observation_types=observation_types,
            source_ids=filter_observation_source_ids(tags=tags),
        ),
        timeout=2,
        config=config,
    )

    return [
@@ -351,11 +357,14 @@ def fetch_file(filename: str) -> dict:
    Text content as string, binary as base64.
    """
    path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
    print("fetching file", path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {filename}")

    mime_type = extract.get_mime_type(path)
    chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
    print("mime_type", mime_type)
    print("chunks", chunks)

def serialize_chunk(
    chunk: extract.DataChunk, data: extract.MulitmodalChunk
src/memory/api/search/scorer.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import asyncio
from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens


SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    try:
        data = chunk.data
    except Exception as e:
        print(f"Error getting chunk data: {e}, {type(e)}")
        return chunk

    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
    try:
        response = await asyncio.to_thread(
            llms.call,
            prompt,
            settings.RANKER_MODEL,
            images=images,
            system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
        )
    except Exception as e:
        print(f"Error scoring chunk: {e}, {type(e)}")
        return chunk

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0

    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )
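A minimal usage sketch of the new scorer, not part of the diff; it assumes the chunks were already loaded from the database (e.g. by search_chunks) and that RANKER_MODEL plus the matching API key are configured:

```python
import asyncio

from memory.api.search import scorer


async def rerank(query: str, chunks):
    # score_chunk gives every chunk a relevance_score in [0, 1]; rank_chunks drops
    # anything below min_score and returns the rest sorted best-first.
    ranked = await scorer.rank_chunks(query, chunks, min_score=0.3)
    return [(chunk.id, chunk.relevance_score) for chunk in ranked]

# asyncio.run(rerank("vector databases", chunks))  # chunks: list[Chunk] from the DB
```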
@@ -12,11 +12,12 @@ from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
from memory.api.search import scorer

if settings.ENABLE_BM25_SEARCH:
    from memory.api.search.bm25 import search_bm25_chunks

from memory.api.search.types import SearchFilters, SearchResult
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult

logger = logging.getLogger(__name__)

@@ -40,7 +41,14 @@ async def search_chunks(
    with make_session() as db:
        chunks = (
            db.query(Chunk)
            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
            .options(
                load_only(
                    Chunk.id,  # type: ignore
                    Chunk.source_id,  # type: ignore
                    Chunk.content,  # type: ignore
                    Chunk.file_paths,  # type: ignore
                )
            )
            .filter(Chunk.id.in_(all_ids))
            .all()
        )
@@ -49,7 +57,7 @@ async def search_chunks(


async def search_sources(
    chunks: list[Chunk], previews: Optional[bool] = False
    chunks: list[Chunk], previews: bool = False
) -> list[SearchResult]:
    by_source = defaultdict(list)
    for chunk in chunks:
@@ -65,11 +73,9 @@ async def search_sources(

async def search(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    filters: SearchFilters = {},
    timeout: int = 2,
    config: SearchConfig = SearchConfig(),
) -> list[SearchResult]:
    """
    Search across knowledge base using text query and optional files.
@@ -87,8 +93,13 @@ async def search(
    chunks = await search_chunks(
        data,
        allowed_modalities,
        limit,
        config.limit,
        filters,
        timeout,
        config.timeout,
    )
    return await search_sources(chunks, previews)
    if settings.ENABLE_SEARCH_SCORING and config.useScores:
        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)

    sources = await search_sources(chunks, config.previews)
    sources.sort(key=lambda x: x.search_score or 0, reverse=True)
    return sources[: config.limit]
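A hedged sketch (not from this change) of driving the reshaped search() end to end, mirroring the call in the MCP tool above; the modalities and source ids are invented:

```python
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract


async def find(query: str):
    # extract_text turns the query into the list[DataChunk] that search() expects.
    data = extract.extract_text(query, skip_summary=True)
    return await search(
        data,
        modalities={"blog", "book"},
        filters=SearchFilters(source_ids=[1, 2, 3]),  # invented ids
        config=SearchConfig(limit=5, previews=True, useScores=True),
    )
```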
@@ -2,10 +2,9 @@ from datetime import datetime
import logging
from typing import Optional, TypedDict, NotRequired, cast

from memory.common.db.models.source_item import SourceItem
from pydantic import BaseModel

from memory.common.db.models import Chunk
from memory.common.db.models import Chunk, SourceItem
from memory.common import settings

logger = logging.getLogger(__name__)

@@ -32,6 +31,7 @@ class SearchResult(BaseModel):
    tags: list[str] | None = None
    metadata: dict | None = None
    created_at: datetime | None = None
    search_score: float | None = None

    @classmethod
    def from_source_item(
@@ -41,6 +41,8 @@ class SearchResult(BaseModel):
        metadata.pop("content", None)
        chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4

        search_score = sum(chunk.relevance_score for chunk in chunks)

        return cls(
            id=cast(int, source.id),
            size=cast(int, source.size),
@@ -56,6 +58,7 @@ class SearchResult(BaseModel):
            tags=cast(list[str], source.tags),
            metadata=metadata,
            created_at=cast(datetime | None, source.inserted_at),
            search_score=search_score,
        )


@@ -65,3 +68,10 @@ class SearchFilters(TypedDict):
    min_confidences: NotRequired[dict[str, float]]
    observation_types: NotRequired[list[str] | None]
    source_ids: NotRequired[list[int] | None]


class SearchConfig(BaseModel):
    limit: int = 10
    timeout: int = 20
    previews: bool = False
    useScores: bool = False
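A small sketch of the new config object (illustrative, assuming the defaults shown above):

```python
from memory.api.search.types import SearchConfig

config = SearchConfig(previews=True, useScores=True, limit=20)
assert config.timeout == 20  # untouched fields keep the class defaults
```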
@@ -2,7 +2,7 @@ import logging
from typing import Iterable, Any
import re

from memory.common import settings
from memory.common import settings, tokens

logger = logging.getLogger(__name__)

@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
OVERLAP_TOKENS = settings.OVERLAP_TOKENS
CHARS_PER_TOKEN = 4


Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def yield_word_chunks(
    text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
    current = ""
    for word in words:
        new_chunk = f"{current} {word}".strip()
        if current and approx_token_count(new_chunk) > max_tokens:
        if current and tokens.approx_token_count(new_chunk) > max_tokens:
            yield current
            current = word
        else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
        if not paragraph.strip():
            continue

        if approx_token_count(paragraph) <= max_tokens:
        if tokens.approx_token_count(paragraph) <= max_tokens:
            yield paragraph
            continue

@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
            if not sentence.strip():
                continue

            if approx_token_count(sentence) <= max_tokens:
            if tokens.approx_token_count(sentence) <= max_tokens:
                yield sentence
                continue

@@ -99,16 +94,16 @@ def chunk_text(
    if not text:
        return

    if approx_token_count(text) <= max_tokens:
    if tokens.approx_token_count(text) <= max_tokens:
        yield text
        return

    overlap_chars = overlap * CHARS_PER_TOKEN
    overlap_chars = overlap * tokens.CHARS_PER_TOKEN
    current = ""

    for span in yield_spans(text, max_tokens):
        current = f"{current} {span}".strip()
        if approx_token_count(current) < max_tokens:
        if tokens.approx_token_count(current) < max_tokens:
            continue

        if overlap <= 0:
@@ -156,9 +156,11 @@ class Chunk(Base):
    collection_name = Column(Text)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    checked_at = Column(DateTime(timezone=True), server_default=func.now())

    vector: list[float] = []
    item_metadata: dict[str, Any] = {}
    images: list[Image.Image] = []
    relevance_score: float = 0.0

    # One of file_path or content must be populated
    __table_args__ = (
src/memory/common/llms.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import logging
import base64
import io
from typing import Any
from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content
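A hedged sketch of the shared entry point (not part of the diff; it assumes "provider/model" strings as in settings and that the matching API key is available; the image path is invented):

```python
from PIL import Image

from memory.common import llms, settings

# Text-only call; the provider is picked from the "anthropic/..." or "openai/..." prefix.
summary = llms.call("Summarise this paragraph: ...", settings.SUMMARIZER_MODEL)

# With an image attached; it is re-encoded to JPEG/base64 internally.
page = Image.open("page.png")
description = llms.call("Describe this page", settings.RANKER_MODEL, images=[page])
```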
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
else:
    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))

# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
@@ -4,7 +4,7 @@ from typing import Any

from bs4 import BeautifulSoup

from memory.common import settings, chunker
from memory.common import settings, tokens, llms

logger = logging.getLogger(__name__)

@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
    return {"summary": summary, "tags": tags}


def _call_openai(prompt: str) -> dict[str, Any]:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        response = client.chat.completions.create(
            model=settings.SUMMARIZER_MODEL.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return parse_response(response.choices[0].message.content or "")
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def _call_anthropic(prompt: str) -> dict[str, Any]:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        response = client.messages.create(
            model=settings.SUMMARIZER_MODEL.split("/")[1],
            messages=[{"role": "user", "content": prompt}],
            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
            temperature=0.3,
            max_tokens=2048,
        )
        return parse_response(response.content[0].text)
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        logger.error(response.content[0].text)
        raise


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content


def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
    """
    Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
    summary, tags = content, []

    # If content is already short enough, just extract tags
    current_tokens = chunker.approx_token_count(content)
    current_tokens = tokens.approx_token_count(content)
    if current_tokens <= target_tokens:
        logger.info(
            f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
    else:
        prompt = SUMMARY_PROMPT.format(
            target_tokens=target_tokens,
            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
            target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
            content=content,
        )

        if chunker.approx_token_count(prompt) > MAX_TOKENS:
        if tokens.approx_token_count(prompt) > MAX_TOKENS:
            logger.warning(
                f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
                f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
            )
            prompt = truncate(prompt, MAX_TOKENS - 20)
            prompt = llms.truncate(prompt, MAX_TOKENS - 20)

        try:
            if settings.SUMMARIZER_MODEL.startswith("anthropic"):
                result = _call_anthropic(prompt)
            else:
                result = _call_openai(prompt)
            response = llms.call(prompt, settings.SUMMARIZER_MODEL)
            result = parse_response(response)

            summary = result.get("summary", "")
            tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
            traceback.print_exc()
            logger.error(f"Summarization failed: {e}")

    tokens = chunker.approx_token_count(summary)
    if tokens > target_tokens * 1.5:
        logger.warning(f"Summary too long ({tokens} tokens), truncating")
        summary = truncate(content, target_tokens)
    summary_tokens = tokens.approx_token_count(summary)
    if summary_tokens > target_tokens * 1.5:
        logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
        summary = llms.truncate(content, target_tokens)

    return summary, tags
src/memory/common/tokens.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)


CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85


def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens


def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens
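A quick worked example of the estimators above (not part of the diff), using an in-memory 1024x1024 image:

```python
from PIL import Image

from memory.common import tokens

img = Image.new("RGB", (1024, 1024))

# OpenAI high detail: the shortest side is scaled down to 768px (768x768 here),
# which spans 2x2 tiles of 512px, so 4 * 170 + 85 = 765 tokens.
assert tokens.estimate_openai_image_tokens(img, detail="high") == 765

# Anthropic: ceil(1024 * 1024 / 1000 * 1.15) + 100 = 1306 tokens.
assert tokens.estimate_anthropic_image_tokens(img) == 1306

# A 400-character prompt adds ~100 text tokens on top of the image cost.
total = tokens.estimate_total_tokens("x" * 400, [img], "anthropic/claude-3-haiku-20240307")
assert total == 100 + 1306
```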
@@ -1,5 +1,6 @@
import pytest
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
from memory.common.tokens import CHARS_PER_TOKEN


@pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
        (" ", []),  # Just spaces
        ("\n\t ", []),  # Whitespace characters
        ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
    ]
    ],
)
def test_yield_word_chunk_basic_behavior(text, expected):
    """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
    [
        (
            "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
        ),
        (
            "supercalifragilisticexpialidocious",
            ["supercalifragilisticexpialidocious"],
        )
    ]
        ),
    ],
)
def test_yield_word_chunk_long_text(text, expected):
    """Test chunking with long text that exceeds token limits"""
@@ -41,7 +42,7 @@ def test_yield_word_chunk_single_long_word():
    """Test behavior with a single word longer than the token limit"""
    max_tokens = 5  # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
    long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2)  # Word twice as long as max

    chunks = list(yield_word_chunks(long_word, max_tokens))
    # With our changes, this should be a single chunk
    assert len(chunks) == 1
@@ -52,8 +53,13 @@ def test_yield_word_chunk_small_token_limit():
    """Test with a very small max_tokens value to force chunking"""
    text = "one two three four five"
    max_tokens = 1  # Very small to force chunking after each word

    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
    assert list(yield_word_chunks(text, max_tokens)) == [
        "one two",
        "three",
        "four",
        "five",
    ]


@pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
        (
            "word1 word2",  # 11 chars with space
            3,  # 12 chars limit
            ["word1 word2"]
            ["word1 word2"],
        ),
        # Text just over token limit should split
        (
            "word1 word2 word3",  # 17 chars with spaces
            4,  # 16 chars limit
            ["word1 word2 word3"]
            ["word1 word2 word3"],
        ),
        # Each word exactly at token limit
        (
            "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
            1,  # 1 token limit (4 chars)
            ["aaaa", "bbbb", "cccc"]
            ["aaaa", "bbbb", "cccc"],
        ),
    ]
    ],
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
    """Test different combinations of text and token limits"""
@@ -95,15 +101,15 @@ def test_yield_word_chunk_real_world_example():
        "It tries to maximize chunk size while staying under the specified token limit. "
        "This behavior is essential for processing large documents efficiently."
    )

    max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
    assert list(yield_word_chunks(text, max_tokens)) == [
        'The yield_word_chunks function splits text',
        'into chunks based on word boundaries. It',
        'tries to maximize chunk size while staying',
        'under the specified token limit. This',
        'behavior is essential for processing large',
        'documents efficiently.',
        "The yield_word_chunks function splits text",
        "into chunks based on word boundaries. It",
        "tries to maximize chunk size while staying",
        "under the specified token limit. This",
        "behavior is essential for processing large",
        "documents efficiently.",
    ]


@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
    "text, expected",
    [
        ("", []),  # Empty text should yield nothing
        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
        (
            "Simple paragraph",
            ["Simple paragraph"],
        ),  # Single paragraph under token limit
        (" ", []),  # Just whitespace
    ]
    ],
)
def test_yield_spans_basic_behavior(text, expected):
    """Test basic behavior of yield_spans with various inputs"""
@@ -135,7 +144,7 @@ def test_yield_spans_sentences():
    sentence1 = "Short sentence one."  # ~20 chars
    sentence2 = "Another short sentence."  # ~24 chars
    text = f"{sentence1} {sentence2}"  # Combined exceeds 5 tokens

    # Function should now preserve punctuation
    expected = ["Short sentence one.", "Another short sentence."]
    assert list(yield_spans(text, max_tokens)) == expected
@@ -145,8 +154,14 @@ def test_yield_spans_words():
    """Test splitting by words when sentences exceed token limit"""
    max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
    long_sentence = "This sentence has several words and needs word-level chunking."

    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
    assert list(yield_spans(long_sentence, max_tokens)) == [
        "This sentence",
        "has several",
        "words and needs",
        "word-level",
        "chunking.",
    ]


def test_yield_spans_complex_document():
@@ -158,7 +173,7 @@ def test_yield_spans_complex_document():
        "This one is longer and might need word splitting depending on the limit.\n\n"
        "Final short paragraph."
    )

    assert list(yield_spans(text, max_tokens)) == [
        "Paragraph one with a short sentence.",
        "And another sentence that should be split.",
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
        "Some are short.",
        "This one is longer and might need word",
        "splitting depending on the limit.",
        "Final short paragraph."
        "Final short paragraph.",
    ]


@@ -175,21 +190,25 @@ def test_yield_spans_very_long_word():
    """Test with a word that exceeds the token limit"""
    max_tokens = 2  # 8 chars with CHARS_PER_TOKEN = 4
    long_word = "supercalifragilisticexpialidocious"  # Much longer than 8 chars

    assert list(yield_spans(long_word, max_tokens)) == [long_word]


def test_yield_spans_with_punctuation():
    """Test sentence splitting with various punctuation"""
    text = "First sentence! Second sentence? Third sentence."

    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
    assert list(yield_spans(text, max_tokens=10)) == [
        "First sentence!",
        "Second sentence?",
        "Third sentence.",
    ]


def test_yield_spans_edge_cases():
    """Test edge cases like empty paragraphs, single character paragraphs"""
    text = "\n\nA\n\n\n\nB\n\n"

    assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]


@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
        ("", []),  # Empty text
        ("Short text", ["Short text"]),  # Text below token limit
        (" ", []),  # Just whitespace
    ]
    ],
)
def test_chunk_text_basic_behavior(text, expected):
    """Test basic behavior of chunk_text with various inputs"""
@@ -223,63 +242,58 @@ def test_chunk_text_long_text():
    # Create a long text that will need multiple chunks
    sentences = [f"This is sentence {i:02}." for i in range(50)]
    text = " ".join(sentences)

    max_tokens = 10  # 10 tokens = ~40 chars
    assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
    ] + [
        'This is sentence 49.'
    ]

        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
    ] + ["This is sentence 49."]


def test_chunk_text_with_overlap():
    """Test chunking with overlap between chunks"""
    # Create text with distinct parts to test overlap
    text = "Part A. Part B. Part C. Part D. Part E."

    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
        "Part A. Part B. Part C.",
        "Part C. Part D. Part E.",
        "Part E.",
    ]


def test_chunk_text_zero_overlap():
    """Test chunking with zero overlap"""
    text = "Part A. Part B. Part C. Part D. Part E."

    # 2 tokens = ~8 chars
    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
        "Part A. Part B.",
        "Part C. Part D.",
        "Part E.",
    ]


def test_chunk_text_clean_break():
    """Test that chunking attempts to break at sentence boundaries"""
    text = "First sentence. Second sentence. Third sentence. Fourth sentence."

    max_tokens = 5  # Enough for about 2 sentences
    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
        "First sentence. Second sentence.",
        "Third sentence. Fourth sentence.",
    ]


def test_chunk_text_very_long_sentences():
    """Test with very long sentences that exceed the token limit"""
    text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."

    max_tokens = 5  # Small limit to force splitting
    assert list(chunk_text(text, max_tokens=max_tokens)) == [
        'This is a very long sentence with many many',
        'words that will definitely exceed the',
        'token limit we set for',
        'this particular test',
        'case and should be split into multiple',
        'chunks by the function.',
        "This is a very long sentence with many many",
        "words that will definitely exceed the",
        "token limit we set for",
        "this particular test",
        "case and should be split into multiple",
        "chunks by the function.",
    ]


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ]
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count
tests/memory/common/test_tokens.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count