Mirror of https://github.com/mruwnik/memory.git (synced 2025-07-30 06:36:07 +02:00)

Compare commits: 06eec621c1 ... 387bd962e6

5 commits: 387bd962e6, 1276b83ffb, 510cfdf82f, 96c2f22b16, 8eb6374cac
```diff
@@ -156,7 +156,9 @@ services:
       STATIC_DIR: "/app/static"
       VOYAGE_API_KEY: ${VOYAGE_API_KEY}
       ENABLE_BM25_SEARCH: false
-    secrets: [postgres_password]
+      OPENAI_API_KEY_FILE: /run/secrets/openai_key
+      ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
+    secrets: [postgres_password, openai_key, anthropic_key]
     volumes:
       - ./memory_files:/app/memory_files:rw
     healthcheck:
```
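The two `*_API_KEY_FILE` variables follow the Docker secrets convention: the container reads the key from a mounted file instead of a plain environment variable. A minimal sketch of the loading pattern they feed (the `read_secret` helper is illustrative, not the repo's exact settings code; the settings.py hunk further down shows the real `ANTHROPIC_API_KEY_FILE` handling):

```python
import os
import pathlib

def read_secret(env_var: str, default: str = "") -> str:
    # Prefer a *_FILE variable pointing at a mounted Docker secret,
    # fall back to a plain environment variable.
    if key_file := os.getenv(f"{env_var}_FILE"):
        return pathlib.Path(key_file).read_text().strip()
    return os.getenv(env_var, default)

OPENAI_API_KEY = read_secret("OPENAI_API_KEY")
ANTHROPIC_API_KEY = read_secret("ANTHROPIC_API_KEY")
```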
```diff
@@ -40,8 +40,7 @@ const Search = () => {
 
     setIsLoading(true)
     try {
-      console.log(params)
-      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
+      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
       setResults(searchResults || [])
     } catch (error) {
       console.error('Search error:', error)
@@ -10,12 +10,17 @@ type Filter = {
   [key: string]: any
 }
 
+type SearchConfig = {
+  previews: boolean
+  useScores: boolean
+  limit: number
+}
+
 export interface SearchParams {
   query: string
-  previews: boolean
   modalities: string[]
   filters: Filter
-  limit: number
+  config: SearchConfig
 }
 
 interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
 export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
   const [query, setQuery] = useState('')
   const [previews, setPreviews] = useState(false)
+  const [useScores, setUseScores] = useState(false)
   const [modalities, setModalities] = useState<Record<string, boolean>>({})
   const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
   const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
 
     onSearch({
       query,
-      previews,
       modalities: getSelectedItems(modalities),
+      config: {
+        previews,
+        useScores,
+        limit
+      },
       filters: {
         tags: getSelectedItems(tags),
         ...cleanFilters(dynamicFilters)
       },
-      limit
     })
   }
 
@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
           Include content previews
         </label>
       </div>
+      <div className="search-option">
+        <label>
+          <input
+            type="checkbox"
+            checked={useScores}
+            onChange={(e) => setUseScores(e.target.checked)}
+          />
+          Score results with an LLM before returning
+        </label>
+      </div>
 
       <SelectableTags
         title="Modalities"
@@ -1,6 +1,4 @@
-import { useState, useEffect } from 'react'
 import ReactMarkdown from 'react-markdown'
-import { useMCP } from '@/hooks/useMCP'
 import { SERVER_URL } from '@/hooks/useAuth'
 
 export type SearchItem = {
@@ -74,25 +72,14 @@ export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: Se
 
 export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
   const title = metadata?.title || filename || 'Untitled'
-  const { fetchFile } = useMCP()
-  const [mime_type, setMimeType] = useState<string>()
-  const [content, setContent] = useState<string>()
-  useEffect(() => {
-    const fetchImage = async () => {
-      const files = await fetchFile(filename)
-      const {mime_type, content} = files[0]
-      setMimeType(mime_type)
-      setContent(content)
-    }
-    fetchImage()
-  }, [filename])
   return (
     <div className="search-result-card">
      <h4>{title}</h4>
      <Tag tags={tags} />
      <Metadata metadata={metadata} />
      <div className="image-container">
-        {mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
+        <img src={`${SERVER_URL}/files/${filename}`} alt={title} className="search-result-image"/>
      </div>
    </div>
  )
@@ -151,13 +151,12 @@ export const useMCP = () => {
     return (await mcpCall('get_metadata_schemas'))[0]
   }, [mcpCall])
 
-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
     return await mcpCall('search_knowledge_base', {
       query,
       filters,
+      config,
       modalities,
-      previews,
-      limit,
     })
   }, [mcpCall])
 
```
```diff
@@ -6,17 +6,16 @@ import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from typing import Any
 from PIL import Image
 
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
-from mcp.server.fastmcp.resources.base import Resource
 
 from memory.api.MCP.tools import mcp
-from memory.api.search.search import SearchFilters, search
+from memory.api.search.search import search
+from memory.api.search.types import SearchFilters, SearchConfig
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    filters: dict[str, Any],
+    filters: SearchFilters,
+    config: SearchConfig = SearchConfig(),
     modalities: set[str] = set(),
-    previews: bool = False,
-    limit: int = 10,
 ) -> list[dict]:
     """
     Search user's stored content including emails, documents, articles, books.
     Use to find specific information the user has saved or received.
     Combine with search_observations for complete user context.
+    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
+
+    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
+
     Args:
         query: Natural language search query - be descriptive about what you're looking for
-        previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        filters: Filter by tags, source_ids, etc.
-        limit: Max results (1-100)
+        filters: a dictionary with the following keys:
+            - tags: a list of tags to filter by
+            - source_ids: a list of source ids to filter by
+            - min_size: the minimum size of the content to filter by
+            - max_size: the maximum size of the content to filter by
+            - min_created_at: the minimum created_at date to filter by
+            - max_created_at: the maximum created_at date to filter by
+        config: a dictionary with the following keys:
+            - limit: the maximum number of results to return
+            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
+            - useScores: whether to score the results with an LLM before returning - better results, but slower
 
     Returns: List of search results with id, score, chunks, content, filename
     Higher scores (>0.7) indicate strong matches.
```
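For illustration, inside an async context a call to the updated tool might look like this (the values are hypothetical; the filter and config keys are the ones documented in the docstring above):

```python
results = await search_knowledge_base(
    query="quarterly report on solar panel efficiency",
    modalities={"email", "blog"},
    filters={"tags": ["energy"], "min_created_at": "2024-01-01"},
    config=SearchConfig(limit=20, previews=True, useScores=True),
)
```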
```diff
@@ -112,10 +121,9 @@ async def search_knowledge_base(
     upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
-        previews=previews,
         modalities=modalities,
-        limit=limit,
         filters=search_filters,
+        config=config,
     )
 
     return [result.model_dump() for result in results]
@@ -211,7 +219,7 @@ async def search_observations(
     tags: list[str] | None = None,
     observation_types: list[str] | None = None,
     min_confidences: dict[str, float] = {},
-    limit: int = 10,
+    config: SearchConfig = SearchConfig(),
 ) -> list[dict]:
     """
     Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
         tags: Filter by tags (must have at least one matching tag)
         observation_types: Filter by: belief, preference, behavior, contradiction, general
         min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
-        limit: Max results (1-100)
+        config: SearchConfig
 
     Returns: List with content, tags, created_at, metadata
     Results sorted by relevance to your query.
@@ -246,9 +254,7 @@ async def search_observations(
             extract.DataChunk(data=[semantic_text]),
             extract.DataChunk(data=[temporal]),
         ],
-        previews=True,
         modalities={"semantic", "temporal"},
-        limit=limit,
         filters=SearchFilters(
             subject=subject,
             min_confidences=min_confidences,
@@ -256,7 +262,7 @@ async def search_observations(
             observation_types=observation_types,
             source_ids=filter_observation_source_ids(tags=tags),
         ),
-        timeout=2,
+        config=config,
     )
 
     return [
@@ -351,11 +357,14 @@ def fetch_file(filename: str) -> dict:
     Text content as string, binary as base64.
     """
     path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
+    print("fetching file", path)
     if not path.exists():
         raise FileNotFoundError(f"File not found: {filename}")
 
     mime_type = extract.get_mime_type(path)
     chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
+    print("mime_type", mime_type)
+    print("chunks", chunks)
 
     def serialize_chunk(
         chunk: extract.DataChunk, data: extract.MulitmodalChunk
```
src/memory/api/search/scorer.py (new file, 74 lines):

```python
import asyncio

from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens


SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    try:
        data = chunk.data
    except Exception as e:
        print(f"Error getting chunk data: {e}, {type(e)}")
        return chunk

    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
    try:
        response = await asyncio.to_thread(
            llms.call,
            prompt,
            settings.RANKER_MODEL,
            images=images,
            system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
        )
    except Exception as e:
        print(f"Error scoring chunk: {e}, {type(e)}")
        return chunk

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0

    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )
```
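The scorer fans out one LLM call per chunk via `asyncio.gather`, so latency stays close to a single call rather than growing linearly with the result count, and any chunk whose scoring fails is passed through unscored rather than dropped. A hypothetical driver showing how `rank_chunks` would be awaited (the `main` function and its arguments are illustrative; `chunks` stands in for real `Chunk` rows):

```python
import asyncio

from memory.api.search import scorer

async def main(query: str, chunks):
    # Keep only chunks the LLM judges at least 30% likely to be relevant,
    # best first - the same threshold search.py uses below.
    ranked = await scorer.rank_chunks(query, chunks, min_score=0.3)
    for chunk in ranked:
        print(chunk.relevance_score, chunk.id)

# asyncio.run(main("solar panel efficiency", chunks))
```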
```diff
@@ -12,11 +12,12 @@ from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 from memory.common.collections import ALL_COLLECTIONS
 from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer
 
 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
 
-from memory.api.search.types import SearchFilters, SearchResult
+from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
 
 logger = logging.getLogger(__name__)
 
@@ -40,7 +41,14 @@ async def search_chunks(
     with make_session() as db:
         chunks = (
             db.query(Chunk)
-            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .options(
+                load_only(
+                    Chunk.id,  # type: ignore
+                    Chunk.source_id,  # type: ignore
+                    Chunk.content,  # type: ignore
+                    Chunk.file_paths,  # type: ignore
+                )
+            )
             .filter(Chunk.id.in_(all_ids))
             .all()
         )
@@ -49,7 +57,7 @@ async def search_chunks(
 
 
 async def search_sources(
-    chunks: list[Chunk], previews: Optional[bool] = False
+    chunks: list[Chunk], previews: bool = False
 ) -> list[SearchResult]:
     by_source = defaultdict(list)
     for chunk in chunks:
@@ -65,11 +73,9 @@ async def search_sources(
 
 async def search(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
-    limit: int = 10,
     filters: SearchFilters = {},
-    timeout: int = 2,
+    config: SearchConfig = SearchConfig(),
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -87,8 +93,13 @@ async def search(
     chunks = await search_chunks(
         data,
         allowed_modalities,
-        limit,
+        config.limit,
         filters,
-        timeout,
+        config.timeout,
     )
-    return await search_sources(chunks, previews)
+    if settings.ENABLE_SEARCH_SCORING and config.useScores:
+        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
+
+    sources = await search_sources(chunks, config.previews)
+    sources.sort(key=lambda x: x.search_score or 0, reverse=True)
+    return sources[: config.limit]
```
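The scoring step is opt-in twice over: the `ENABLE_SEARCH_SCORING` setting gates it globally and `config.useScores` gates it per request. Note also that `rank_chunks` receives `data[0].data[0]` as its query, i.e. the pipeline assumes the first element of the first `DataChunk` is the query text. A sketch of a direct call into the new pipeline (the `scored_search` wrapper and query values are illustrative):

```python
from memory.api.search.search import search
from memory.api.search.types import SearchConfig
from memory.common import extract

async def scored_search():
    # useScores only takes effect when ENABLE_SEARCH_SCORING is also on.
    return await search(
        [extract.DataChunk(data=["solar panel efficiency"])],
        modalities={"blog", "email"},
        config=SearchConfig(limit=10, previews=True, useScores=True),
    )
```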
```diff
@@ -2,10 +2,9 @@ from datetime import datetime
 import logging
 from typing import Optional, TypedDict, NotRequired, cast
 
-from memory.common.db.models.source_item import SourceItem
 from pydantic import BaseModel
 
-from memory.common.db.models import Chunk
+from memory.common.db.models import Chunk, SourceItem
 from memory.common import settings
 
 logger = logging.getLogger(__name__)
@@ -32,6 +31,7 @@ class SearchResult(BaseModel):
     tags: list[str] | None = None
     metadata: dict | None = None
     created_at: datetime | None = None
+    search_score: float | None = None
 
     @classmethod
     def from_source_item(
@@ -41,6 +41,8 @@ class SearchResult(BaseModel):
         metadata.pop("content", None)
         chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
 
+        search_score = sum(chunk.relevance_score for chunk in chunks)
+
         return cls(
             id=cast(int, source.id),
             size=cast(int, source.size),
@@ -56,6 +58,7 @@ class SearchResult(BaseModel):
             tags=cast(list[str], source.tags),
             metadata=metadata,
             created_at=cast(datetime | None, source.inserted_at),
+            search_score=search_score,
         )
 
 
@@ -65,3 +68,10 @@ class SearchFilters(TypedDict):
     min_confidences: NotRequired[dict[str, float]]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
+
+
+class SearchConfig(BaseModel):
+    limit: int = 10
+    timeout: int = 20
+    previews: bool = False
+    useScores: bool = False
```
```diff
@@ -2,7 +2,7 @@ import logging
 from typing import Iterable, Any
 import re
 
-from memory.common import settings
+from memory.common import settings, tokens
 
 logger = logging.getLogger(__name__)
 
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
 EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
 DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
 OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4
 
 
 Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 _SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
 
 
-def approx_token_count(s: str) -> int:
-    return len(s) // CHARS_PER_TOKEN
-
-
 def yield_word_chunks(
     text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
 ) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
-        if current and approx_token_count(new_chunk) > max_tokens:
+        if current and tokens.approx_token_count(new_chunk) > max_tokens:
             yield current
             current = word
         else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
         if not paragraph.strip():
             continue
 
-        if approx_token_count(paragraph) <= max_tokens:
+        if tokens.approx_token_count(paragraph) <= max_tokens:
             yield paragraph
             continue
 
@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
             if not sentence.strip():
                 continue
 
-            if approx_token_count(sentence) <= max_tokens:
+            if tokens.approx_token_count(sentence) <= max_tokens:
                 yield sentence
                 continue
 
@@ -99,16 +94,16 @@ def chunk_text(
     if not text:
         return
 
-    if approx_token_count(text) <= max_tokens:
+    if tokens.approx_token_count(text) <= max_tokens:
         yield text
         return
 
-    overlap_chars = overlap * CHARS_PER_TOKEN
+    overlap_chars = overlap * tokens.CHARS_PER_TOKEN
     current = ""
 
     for span in yield_spans(text, max_tokens):
         current = f"{current} {span}".strip()
-        if approx_token_count(current) < max_tokens:
+        if tokens.approx_token_count(current) < max_tokens:
             continue
 
         if overlap <= 0:
```
```diff
@@ -156,9 +156,11 @@ class Chunk(Base):
     collection_name = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     checked_at = Column(DateTime(timezone=True), server_default=func.now())
 
     vector: list[float] = []
     item_metadata: dict[str, Any] = {}
     images: list[Image.Image] = []
+    relevance_score: float = 0.0
 
     # One of file_path or content must be populated
     __table_args__ = (
```
src/memory/common/llms.py (new file, 122 lines):

```python
import logging
import base64
import io
from typing import Any
from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content
```
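`call` dispatches on the `provider/model` prefix, so callers stay provider-agnostic. A small usage sketch (the model string follows the repo's existing `anthropic/...` convention; the image path is illustrative):

```python
from PIL import Image

from memory.common import llms

answer = llms.call(
    "Describe this image in one sentence.",
    "anthropic/claude-3-haiku-20240307",
    images=[Image.open("photo.jpg")],
)
```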
```diff
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
 
 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
 
```
```diff
@@ -4,7 +4,7 @@ from typing import Any
 
 from bs4 import BeautifulSoup
 
-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms
 
 logger = logging.getLogger(__name__)
 
@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
     return {"summary": summary, "tags": tags}
 
 
-def _call_openai(prompt: str) -> dict[str, Any]:
-    """Call OpenAI API for summarization."""
-    import openai
-
-    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
-    try:
-        response = client.chat.completions.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
-                },
-                {"role": "user", "content": prompt},
-            ],
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.choices[0].message.content or "")
-    except Exception as e:
-        logger.error(f"OpenAI API error: {e}")
-        raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
-    """Call Anthropic API for summarization."""
-    import anthropic
-
-    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
-    try:
-        response = client.messages.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[{"role": "user", "content": prompt}],
-            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.content[0].text)
-    except Exception as e:
-        logger.error(f"Anthropic API error: {e}")
-        logger.error(response.content[0].text)
-        raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
-    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
-    if len(content) > target_chars:
-        return content[:target_chars].rsplit(" ", 1)[0] + "..."
-    return content
-
-
 def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     """
     Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     summary, tags = content, []
 
     # If content is already short enough, just extract tags
-    current_tokens = chunker.approx_token_count(content)
+    current_tokens = tokens.approx_token_count(content)
     if current_tokens <= target_tokens:
         logger.info(
             f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     else:
         prompt = SUMMARY_PROMPT.format(
             target_tokens=target_tokens,
-            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+            target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
             content=content,
         )
 
-        if chunker.approx_token_count(prompt) > MAX_TOKENS:
+        if tokens.approx_token_count(prompt) > MAX_TOKENS:
             logger.warning(
-                f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
+                f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
             )
-            prompt = truncate(prompt, MAX_TOKENS - 20)
+            prompt = llms.truncate(prompt, MAX_TOKENS - 20)
 
         try:
-            if settings.SUMMARIZER_MODEL.startswith("anthropic"):
-                result = _call_anthropic(prompt)
-            else:
-                result = _call_openai(prompt)
+            response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+            result = parse_response(response)
 
             summary = result.get("summary", "")
             tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
             traceback.print_exc()
             logger.error(f"Summarization failed: {e}")
 
-    tokens = chunker.approx_token_count(summary)
-    if tokens > target_tokens * 1.5:
-        logger.warning(f"Summary too long ({tokens} tokens), truncating")
-        summary = truncate(content, target_tokens)
+    summary_tokens = tokens.approx_token_count(summary)
+    if summary_tokens > target_tokens * 1.5:
+        logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
+        summary = llms.truncate(content, target_tokens)
 
     return summary, tags
```
src/memory/common/tokens.py (new file, 128 lines):

```python
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)


CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85


def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens


def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens
```
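As a worked check of the arithmetic above: a 1024x768 image at high detail already fits inside 2048x2048 and its shortest side is already 768px, so no resizing happens; it tiles into ceil(1024/512) x ceil(768/512) = 2 x 2 = 4 tiles, giving 4 * 170 + 85 = 765 tokens under the OpenAI estimate. The Anthropic estimate for the same image is 100 + ceil(1024 * 768 / 1000 * 1.15) = 100 + 905 = 1005 tokens.

```python
from PIL import Image

from memory.common import tokens

# Worked example: the estimates derived above, checked against the code.
img = Image.new("RGB", (1024, 768))
assert tokens.estimate_openai_image_tokens(img) == 765
assert tokens.estimate_anthropic_image_tokens(img) == 1005
```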
```diff
@@ -1,5 +1,6 @@
 import pytest
-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN
 
 
 @pytest.mark.parametrize(
@@ -12,7 +13,7 @@
         (" ", []),  # Just spaces
         ("\n\t ", []),  # Whitespace characters
         ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
-    ]
+    ],
 )
 def test_yield_word_chunk_basic_behavior(text, expected):
     """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@
     [
         (
             "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
-            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
         ),
         (
             "supercalifragilisticexpialidocious",
             ["supercalifragilisticexpialidocious"],
-        )
-    ]
+        ),
+    ],
 )
 def test_yield_word_chunk_long_text(text, expected):
     """Test chunking with long text that exceeds token limits"""
@@ -41,7 +42,7 @@ def test_yield_word_chunk_single_long_word():
     """Test behavior with a single word longer than the token limit"""
     max_tokens = 5  # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
     long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2)  # Word twice as long as max
 
     chunks = list(yield_word_chunks(long_word, max_tokens))
     # With our changes, this should be a single chunk
     assert len(chunks) == 1
@@ -52,8 +53,13 @@ def test_yield_word_chunk_small_token_limit():
     """Test with a very small max_tokens value to force chunking"""
     text = "one two three four five"
     max_tokens = 1  # Very small to force chunking after each word
 
-    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+    assert list(yield_word_chunks(text, max_tokens)) == [
+        "one two",
+        "three",
+        "four",
+        "five",
+    ]
 
 
 @pytest.mark.parametrize(
@@ -67,21 +73,21 @@
         (
             "word1 word2",  # 11 chars with space
             3,  # 12 chars limit
-            ["word1 word2"]
+            ["word1 word2"],
         ),
         # Text just over token limit should split
         (
             "word1 word2 word3",  # 17 chars with spaces
             4,  # 16 chars limit
-            ["word1 word2 word3"]
+            ["word1 word2 word3"],
         ),
         # Each word exactly at token limit
         (
             "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
             1,  # 1 token limit (4 chars)
-            ["aaaa", "bbbb", "cccc"]
+            ["aaaa", "bbbb", "cccc"],
         ),
-    ]
+    ],
 )
 def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
     """Test different combinations of text and token limits"""
@@ -95,15 +101,15 @@ def test_yield_word_chunk_real_world_example():
         "It tries to maximize chunk size while staying under the specified token limit. "
         "This behavior is essential for processing large documents efficiently."
     )
 
     max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
     assert list(yield_word_chunks(text, max_tokens)) == [
-        'The yield_word_chunks function splits text',
-        'into chunks based on word boundaries. It',
-        'tries to maximize chunk size while staying',
-        'under the specified token limit. This',
-        'behavior is essential for processing large',
-        'documents efficiently.',
+        "The yield_word_chunks function splits text",
+        "into chunks based on word boundaries. It",
+        "tries to maximize chunk size while staying",
+        "under the specified token limit. This",
+        "behavior is essential for processing large",
+        "documents efficiently.",
     ]
 
 
@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
     "text, expected",
     [
         ("", []),  # Empty text should yield nothing
-        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
+        (
+            "Simple paragraph",
+            ["Simple paragraph"],
+        ),  # Single paragraph under token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_yield_spans_basic_behavior(text, expected):
     """Test basic behavior of yield_spans with various inputs"""
@@ -135,7 +144,7 @@ def test_yield_spans_sentences():
     sentence1 = "Short sentence one."  # ~20 chars
     sentence2 = "Another short sentence."  # ~24 chars
     text = f"{sentence1} {sentence2}"  # Combined exceeds 5 tokens
 
     # Function should now preserve punctuation
     expected = ["Short sentence one.", "Another short sentence."]
     assert list(yield_spans(text, max_tokens)) == expected
@@ -145,8 +154,14 @@ def test_yield_spans_words():
     """Test splitting by words when sentences exceed token limit"""
     max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
     long_sentence = "This sentence has several words and needs word-level chunking."
 
-    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+    assert list(yield_spans(long_sentence, max_tokens)) == [
+        "This sentence",
+        "has several",
+        "words and needs",
+        "word-level",
+        "chunking.",
+    ]
 
 
 def test_yield_spans_complex_document():
@@ -158,7 +173,7 @@ def test_yield_spans_complex_document():
         "This one is longer and might need word splitting depending on the limit.\n\n"
         "Final short paragraph."
     )
 
     assert list(yield_spans(text, max_tokens)) == [
         "Paragraph one with a short sentence.",
         "And another sentence that should be split.",
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
         "Some are short.",
         "This one is longer and might need word",
         "splitting depending on the limit.",
-        "Final short paragraph."
+        "Final short paragraph.",
     ]
 
 
@@ -175,21 +190,25 @@ def test_yield_spans_very_long_word():
     """Test with a word that exceeds the token limit"""
     max_tokens = 2  # 8 chars with CHARS_PER_TOKEN = 4
     long_word = "supercalifragilisticexpialidocious"  # Much longer than 8 chars
 
     assert list(yield_spans(long_word, max_tokens)) == [long_word]
 
 
 def test_yield_spans_with_punctuation():
     """Test sentence splitting with various punctuation"""
     text = "First sentence! Second sentence? Third sentence."
 
-    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+    assert list(yield_spans(text, max_tokens=10)) == [
+        "First sentence!",
+        "Second sentence?",
+        "Third sentence.",
+    ]
 
 
 def test_yield_spans_edge_cases():
     """Test edge cases like empty paragraphs, single character paragraphs"""
     text = "\n\nA\n\n\n\nB\n\n"
 
     assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
 
 
@@ -199,7 +218,7 @@
         ("", []),  # Empty text
         ("Short text", ["Short text"]),  # Text below token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_chunk_text_basic_behavior(text, expected):
     """Test basic behavior of chunk_text with various inputs"""
@@ -223,63 +242,58 @@ def test_chunk_text_long_text():
     # Create a long text that will need multiple chunks
     sentences = [f"This is sentence {i:02}." for i in range(50)]
     text = " ".join(sentences)
 
     max_tokens = 10  # 10 tokens = ~40 chars
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
-        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
-    ] + [
-        'This is sentence 49.'
-    ]
+        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+    ] + ["This is sentence 49."]
 
 
 def test_chunk_text_with_overlap():
     """Test chunking with overlap between chunks"""
     # Create text with distinct parts to test overlap
     text = "Part A. Part B. Part C. Part D. Part E."
 
-    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+        "Part A. Part B. Part C.",
+        "Part C. Part D. Part E.",
+        "Part E.",
+    ]
 
 
 def test_chunk_text_zero_overlap():
     """Test chunking with zero overlap"""
     text = "Part A. Part B. Part C. Part D. Part E."
 
     # 2 tokens = ~8 chars
-    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+        "Part A. Part B.",
+        "Part C. Part D.",
+        "Part E.",
+    ]
 
 
 def test_chunk_text_clean_break():
     """Test that chunking attempts to break at sentence boundaries"""
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."
 
     max_tokens = 5  # Enough for about 2 sentences
-    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+        "First sentence. Second sentence.",
+        "Third sentence. Fourth sentence.",
+    ]
 
 
 def test_chunk_text_very_long_sentences():
     """Test with very long sentences that exceed the token limit"""
     text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
 
     max_tokens = 5  # Small limit to force splitting
     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        'This is a very long sentence with many many',
-        'words that will definitely exceed the',
-        'token limit we set for',
-        'this particular test',
-        'case and should be split into multiple',
-        'chunks by the function.',
+        "This is a very long sentence with many many",
+        "words that will definitely exceed the",
+        "token limit we set for",
+        "this particular test",
+        "case and should be split into multiple",
+        "chunks by the function.",
     ]
 
 
-@pytest.mark.parametrize(
-    "string, expected_count",
-    [
-        ("", 0),
-        ("a" * CHARS_PER_TOKEN, 1),
-        ("a" * (CHARS_PER_TOKEN * 2), 2),
-        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
-        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
-    ]
-)
-def test_approx_token_count(string, expected_count):
-    assert approx_token_count(string) == expected_count
```
tests/memory/common/test_tokens.py (new file, 16 lines):

```python
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count
```