Mirror of https://github.com/mruwnik/memory.git, synced 2025-07-31 07:06:07 +02:00

Compare commits: 06eec621c1 ... 387bd962e6 (5 commits)

- 387bd962e6
- 1276b83ffb
- 510cfdf82f
- 96c2f22b16
- 8eb6374cac
```diff
@@ -156,7 +156,9 @@ services:
       STATIC_DIR: "/app/static"
       VOYAGE_API_KEY: ${VOYAGE_API_KEY}
       ENABLE_BM25_SEARCH: false
-    secrets: [postgres_password]
+      OPENAI_API_KEY_FILE: /run/secrets/openai_key
+      ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
+    secrets: [postgres_password, openai_key, anthropic_key]
     volumes:
       - ./memory_files:/app/memory_files:rw
     healthcheck:
```
```diff
@@ -40,8 +40,7 @@ const Search = () => {
 
     setIsLoading(true)
     try {
-      console.log(params)
-      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
+      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
       setResults(searchResults || [])
     } catch (error) {
       console.error('Search error:', error)
```
```diff
@@ -10,12 +10,17 @@ type Filter = {
   [key: string]: any
 }
 
+type SearchConfig = {
+  previews: boolean
+  useScores: boolean
+  limit: number
+}
+
 export interface SearchParams {
   query: string
-  previews: boolean
   modalities: string[]
   filters: Filter
-  limit: number
+  config: SearchConfig
 }
 
 interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
 export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
   const [query, setQuery] = useState('')
   const [previews, setPreviews] = useState(false)
+  const [useScores, setUseScores] = useState(false)
   const [modalities, setModalities] = useState<Record<string, boolean>>({})
   const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
   const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
 
     onSearch({
       query,
-      previews,
       modalities: getSelectedItems(modalities),
+      config: {
+        previews,
+        useScores,
+        limit
+      },
       filters: {
         tags: getSelectedItems(tags),
         ...cleanFilters(dynamicFilters)
       },
-      limit
     })
   }
 
@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
           Include content previews
         </label>
       </div>
+      <div className="search-option">
+        <label>
+          <input
+            type="checkbox"
+            checked={useScores}
+            onChange={(e) => setUseScores(e.target.checked)}
+          />
+          Score results with a LLM before returning
+        </label>
+      </div>
 
       <SelectableTags
         title="Modalities"
```
```diff
@@ -1,6 +1,4 @@
-import { useState, useEffect } from 'react'
 import ReactMarkdown from 'react-markdown'
-import { useMCP } from '@/hooks/useMCP'
 import { SERVER_URL } from '@/hooks/useAuth'
 
 export type SearchItem = {
@@ -74,25 +72,14 @@ export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
 
 export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
     const title = metadata?.title || filename || 'Untitled'
-    const { fetchFile } = useMCP()
-    const [mime_type, setMimeType] = useState<string>()
-    const [content, setContent] = useState<string>()
-    useEffect(() => {
-        const fetchImage = async () => {
-            const files = await fetchFile(filename)
-            const {mime_type, content} = files[0]
-            setMimeType(mime_type)
-            setContent(content)
-        }
-        fetchImage()
-    }, [filename])
     return (
         <div className="search-result-card">
            <h4>{title}</h4>
            <Tag tags={tags} />
            <Metadata metadata={metadata} />
            <div className="image-container">
-               {mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
+               <img src={`${SERVER_URL}/files/${filename}`} alt={title} className="search-result-image"/>
            </div>
        </div>
    )
```
```diff
@@ -151,13 +151,12 @@ export const useMCP = () => {
     return (await mcpCall('get_metadata_schemas'))[0]
   }, [mcpCall])
 
-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
     return await mcpCall('search_knowledge_base', {
       query,
       filters,
+      config,
       modalities,
-      previews,
-      limit,
     })
   }, [mcpCall])
 
```
```diff
@@ -6,17 +6,16 @@ import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from typing import Any
 from PIL import Image
 
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
-from mcp.server.fastmcp.resources.base import Resource
 
 from memory.api.MCP.tools import mcp
-from memory.api.search.search import SearchFilters, search
+from memory.api.search.search import search
+from memory.api.search.types import SearchFilters, SearchConfig
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]:
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    filters: dict[str, Any],
+    filters: SearchFilters,
+    config: SearchConfig = SearchConfig(),
     modalities: set[str] = set(),
-    previews: bool = False,
-    limit: int = 10,
 ) -> list[dict]:
     """
     Search user's stored content including emails, documents, articles, books.
     Use to find specific information the user has saved or received.
     Combine with search_observations for complete user context.
+    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
+
+    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
 
     Args:
         query: Natural language search query - be descriptive about what you're looking for
-        previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        filters: Filter by tags, source_ids, etc.
-        limit: Max results (1-100)
+        filters: a dictionary with the following keys:
+            - tags: a list of tags to filter by
+            - source_ids: a list of source ids to filter by
+            - min_size: the minimum size of the content to filter by
+            - max_size: the maximum size of the content to filter by
+            - min_created_at: the minimum created_at date to filter by
+            - max_created_at: the maximum created_at date to filter by
+        config: a dictionary with the following keys:
+            - limit: the maximum number of results to return
+            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
+            - useScores: whether to score the results with a LLM before returning - this results in better results but is slower
 
     Returns: List of search results with id, score, chunks, content, filename
     Higher scores (>0.7) indicate strong matches.
```
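The docstring above fully specifies the new call shape: result shaping moves into `config`, typed filtering into `filters`. A minimal sketch of invoking the updated tool from an async context; the query, tag, date, and modality values are illustrative, not taken from the diff:

```python
from memory.api.search.types import SearchConfig

# Hypothetical call; filter values are illustrative and may need datetime objects.
results = await search_knowledge_base(
    query="meeting notes about the vector store migration",
    filters={"tags": ["work"], "min_created_at": "2025-01-01"},
    config=SearchConfig(limit=5, previews=True, useScores=True),
    modalities={"email", "blog"},
)
for result in results:
    print(result["id"], result.get("filename"))
```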
```diff
@@ -112,10 +121,9 @@ async def search_knowledge_base(
     upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
-        previews=previews,
         modalities=modalities,
-        limit=limit,
         filters=search_filters,
+        config=config,
     )
 
     return [result.model_dump() for result in results]
@@ -211,7 +219,7 @@ async def search_observations(
     tags: list[str] | None = None,
     observation_types: list[str] | None = None,
     min_confidences: dict[str, float] = {},
-    limit: int = 10,
+    config: SearchConfig = SearchConfig(),
 ) -> list[dict]:
     """
     Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
         tags: Filter by tags (must have at least one matching tag)
         observation_types: Filter by: belief, preference, behavior, contradiction, general
         min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
-        limit: Max results (1-100)
+        config: SearchConfig
 
     Returns: List with content, tags, created_at, metadata
     Results sorted by relevance to your query.
@@ -246,9 +254,7 @@ async def search_observations(
             extract.DataChunk(data=[semantic_text]),
             extract.DataChunk(data=[temporal]),
         ],
-        previews=True,
         modalities={"semantic", "temporal"},
-        limit=limit,
         filters=SearchFilters(
             subject=subject,
             min_confidences=min_confidences,
@@ -256,7 +262,7 @@ async def search_observations(
             observation_types=observation_types,
             source_ids=filter_observation_source_ids(tags=tags),
         ),
-        timeout=2,
+        config=config,
     )
 
     return [
@@ -351,11 +357,14 @@ def fetch_file(filename: str) -> dict:
     Text content as string, binary as base64.
     """
     path = settings.FILE_STORAGE_DIR / filename.strip().lstrip("/")
+    print("fetching file", path)
     if not path.exists():
         raise FileNotFoundError(f"File not found: {filename}")
 
     mime_type = extract.get_mime_type(path)
     chunks = extract.extract_data_chunks(mime_type, path, skip_summary=True)
+    print("mime_type", mime_type)
+    print("chunks", chunks)
 
     def serialize_chunk(
         chunk: extract.DataChunk, data: extract.MulitmodalChunk
```
src/memory/api/search/scorer.py (new file, 74 lines):

```python
import asyncio
from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens


SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    try:
        data = chunk.data
    except Exception as e:
        print(f"Error getting chunk data: {e}, {type(e)}")
        return chunk

    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
    try:
        response = await asyncio.to_thread(
            llms.call,
            prompt,
            settings.RANKER_MODEL,
            images=images,
            system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
        )
    except Exception as e:
        print(f"Error scoring chunk: {e}, {type(e)}")
        return chunk

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0

    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )
```
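A hedged sketch of exercising the new scorer on its own, assuming a database session and some persisted chunks; the query string and `min_score` are illustrative:

```python
import asyncio

from memory.api.search import scorer
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk


async def demo() -> None:
    with make_session() as db:
        chunks = db.query(Chunk).limit(10).all()
        # Keep only chunks the LLM judges at least 30% likely to be relevant.
        ranked = await scorer.rank_chunks("solar panel efficiency", chunks, min_score=0.3)
        for chunk in ranked:
            print(chunk.id, chunk.relevance_score)


asyncio.run(demo())
```

Each chunk is scored in its own `asyncio.to_thread` call, so the LLM requests run concurrently under `asyncio.gather`.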
```diff
@@ -12,11 +12,12 @@ from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 from memory.common.collections import ALL_COLLECTIONS
 from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer
 
 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
 
-from memory.api.search.types import SearchFilters, SearchResult
+from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
 
 logger = logging.getLogger(__name__)
 
@@ -40,7 +41,14 @@ async def search_chunks(
     with make_session() as db:
         chunks = (
             db.query(Chunk)
-            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .options(
+                load_only(
+                    Chunk.id,  # type: ignore
+                    Chunk.source_id,  # type: ignore
+                    Chunk.content,  # type: ignore
+                    Chunk.file_paths,  # type: ignore
+                )
+            )
             .filter(Chunk.id.in_(all_ids))
             .all()
         )
@@ -49,7 +57,7 @@ async def search_chunks(
 
 
 async def search_sources(
-    chunks: list[Chunk], previews: Optional[bool] = False
+    chunks: list[Chunk], previews: bool = False
 ) -> list[SearchResult]:
     by_source = defaultdict(list)
     for chunk in chunks:
@@ -65,11 +73,9 @@ async def search_sources(
 
 async def search(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
-    limit: int = 10,
     filters: SearchFilters = {},
-    timeout: int = 2,
+    config: SearchConfig = SearchConfig(),
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -87,8 +93,13 @@ async def search(
     chunks = await search_chunks(
         data,
         allowed_modalities,
-        limit,
+        config.limit,
         filters,
-        timeout,
+        config.timeout,
     )
-    return await search_sources(chunks, previews)
+    if settings.ENABLE_SEARCH_SCORING and config.useScores:
+        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
+
+    sources = await search_sources(chunks, config.previews)
+    sources.sort(key=lambda x: x.search_score or 0, reverse=True)
+    return sources[: config.limit]
```
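With scoring folded into `search()`, a caller passes a single `SearchConfig` instead of loose `previews`/`limit`/`timeout` arguments. A minimal sketch under those assumptions (the query text, tag, and modalities are illustrative):

```python
from memory.api.search.search import search
from memory.api.search.types import SearchConfig, SearchFilters
from memory.common import extract


async def run_query() -> None:
    # extract_text produces the list[DataChunk] that search() expects.
    data = extract.extract_text("habits around note taking", skip_summary=True)
    results = await search(
        data,
        modalities={"blog", "email"},
        filters=SearchFilters(tags=["productivity"]),
        config=SearchConfig(limit=5, previews=True, useScores=True),
    )
    for result in results:
        print(result.id, result.search_score)
```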
```diff
@@ -2,10 +2,9 @@ from datetime import datetime
 import logging
 from typing import Optional, TypedDict, NotRequired, cast
 
-from memory.common.db.models.source_item import SourceItem
 from pydantic import BaseModel
 
-from memory.common.db.models import Chunk
+from memory.common.db.models import Chunk, SourceItem
 from memory.common import settings
 
 logger = logging.getLogger(__name__)
@@ -32,6 +31,7 @@ class SearchResult(BaseModel):
     tags: list[str] | None = None
     metadata: dict | None = None
     created_at: datetime | None = None
+    search_score: float | None = None
 
     @classmethod
     def from_source_item(
@@ -41,6 +41,8 @@ class SearchResult(BaseModel):
         metadata.pop("content", None)
         chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
 
+        search_score = sum(chunk.relevance_score for chunk in chunks)
+
         return cls(
             id=cast(int, source.id),
             size=cast(int, source.size),
@@ -56,6 +58,7 @@ class SearchResult(BaseModel):
             tags=cast(list[str], source.tags),
             metadata=metadata,
             created_at=cast(datetime | None, source.inserted_at),
+            search_score=search_score,
         )
 
 
@@ -65,3 +68,10 @@ class SearchFilters(TypedDict):
     min_confidences: NotRequired[dict[str, float]]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
+
+
+class SearchConfig(BaseModel):
+    limit: int = 10
+    timeout: int = 20
+    previews: bool = False
+    useScores: bool = False
```
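Since `SearchConfig` is a pydantic model, partial construction and serialization come for free. For example:

```python
from memory.api.search.types import SearchConfig

# Unset fields fall back to the defaults declared above.
config = SearchConfig(limit=5, previews=True, useScores=True)
assert config.timeout == 20  # default kept
print(config.model_dump())   # pydantic BaseModel, so it serializes cleanly
```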
```diff
@@ -2,7 +2,7 @@ import logging
 from typing import Iterable, Any
 import re
 
-from memory.common import settings
+from memory.common import settings, tokens
 
 logger = logging.getLogger(__name__)
 
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
 EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
 DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
 OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4
 
 
 Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 _SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
 
 
-def approx_token_count(s: str) -> int:
-    return len(s) // CHARS_PER_TOKEN
-
-
 def yield_word_chunks(
     text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
 ) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
-        if current and approx_token_count(new_chunk) > max_tokens:
+        if current and tokens.approx_token_count(new_chunk) > max_tokens:
             yield current
             current = word
         else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
         if not paragraph.strip():
             continue
 
-        if approx_token_count(paragraph) <= max_tokens:
+        if tokens.approx_token_count(paragraph) <= max_tokens:
             yield paragraph
             continue
 
@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[str]:
         if not sentence.strip():
             continue
 
-        if approx_token_count(sentence) <= max_tokens:
+        if tokens.approx_token_count(sentence) <= max_tokens:
             yield sentence
             continue
 
@@ -99,16 +94,16 @@ def chunk_text(
     if not text:
         return
 
-    if approx_token_count(text) <= max_tokens:
+    if tokens.approx_token_count(text) <= max_tokens:
         yield text
         return
 
-    overlap_chars = overlap * CHARS_PER_TOKEN
+    overlap_chars = overlap * tokens.CHARS_PER_TOKEN
     current = ""
 
     for span in yield_spans(text, max_tokens):
         current = f"{current} {span}".strip()
-        if approx_token_count(current) < max_tokens:
+        if tokens.approx_token_count(current) < max_tokens:
             continue
 
         if overlap <= 0:
```
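The token heuristic that moved to `memory.common.tokens` is just string length divided by `CHARS_PER_TOKEN` (4). For example:

```python
from memory.common import tokens

# 23 characters // 4 chars-per-token = 5 tokens
print(tokens.approx_token_count("approximately five toks"))  # 5
```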
```diff
@@ -156,9 +156,11 @@ class Chunk(Base):
     collection_name = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     checked_at = Column(DateTime(timezone=True), server_default=func.now())
 
     vector: list[float] = []
     item_metadata: dict[str, Any] = {}
     images: list[Image.Image] = []
+    relevance_score: float = 0.0
+
     # One of file_path or content must be populated
     __table_args__ = (
```
src/memory/common/llms.py (new file, 122 lines):

```python
import logging
import base64
import io
from typing import Any
from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content
```
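A minimal sketch of the shared wrapper in use; the image path is hypothetical and the model string comes from settings (the `provider/model` prefix selects the backend):

```python
from PIL import Image

from memory.common import llms, settings

caption = llms.call(
    "Describe this image in one sentence.",
    settings.SUMMARIZER_MODEL,  # e.g. "anthropic/claude-3-haiku-20240307"
    images=[Image.open("photo.jpg")],  # hypothetical local file
)
print(caption)
```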
```diff
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))
 
 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
 
```
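These knobs are plain environment variables, so scoring can be retuned or disabled without code changes. A hedged sketch (the model name is illustrative, and the variables must be set before the settings module is first imported):

```python
import os

os.environ["RANKER_MODEL"] = "openai/gpt-4o-mini"  # illustrative model name
os.environ["ENABLE_SEARCH_SCORING"] = "false"      # read via boolean_env at import time

from memory.common import settings  # values are resolved here
```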
```diff
@@ -4,7 +4,7 @@ from typing import Any
 
 from bs4 import BeautifulSoup
 
-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms
 
 logger = logging.getLogger(__name__)
 
@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
     return {"summary": summary, "tags": tags}
 
 
-def _call_openai(prompt: str) -> dict[str, Any]:
-    """Call OpenAI API for summarization."""
-    import openai
-
-    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
-    try:
-        response = client.chat.completions.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
-                },
-                {"role": "user", "content": prompt},
-            ],
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.choices[0].message.content or "")
-    except Exception as e:
-        logger.error(f"OpenAI API error: {e}")
-        raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
-    """Call Anthropic API for summarization."""
-    import anthropic
-
-    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
-    try:
-        response = client.messages.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[{"role": "user", "content": prompt}],
-            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.content[0].text)
-    except Exception as e:
-        logger.error(f"Anthropic API error: {e}")
-        logger.error(response.content[0].text)
-        raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
-    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
-    if len(content) > target_chars:
-        return content[:target_chars].rsplit(" ", 1)[0] + "..."
-    return content
-
-
 def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     """
     Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     summary, tags = content, []
 
     # If content is already short enough, just extract tags
-    current_tokens = chunker.approx_token_count(content)
+    current_tokens = tokens.approx_token_count(content)
     if current_tokens <= target_tokens:
         logger.info(
             f"Content already under {target_tokens} tokens, extracting tags only"
@@ -145,21 +94,19 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     else:
         prompt = SUMMARY_PROMPT.format(
             target_tokens=target_tokens,
-            target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
+            target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
             content=content,
         )
 
-        if chunker.approx_token_count(prompt) > MAX_TOKENS:
+        if tokens.approx_token_count(prompt) > MAX_TOKENS:
             logger.warning(
-                f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
+                f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
             )
-            prompt = truncate(prompt, MAX_TOKENS - 20)
+            prompt = llms.truncate(prompt, MAX_TOKENS - 20)
 
         try:
-            if settings.SUMMARIZER_MODEL.startswith("anthropic"):
-                result = _call_anthropic(prompt)
-            else:
-                result = _call_openai(prompt)
+            response = llms.call(prompt, settings.SUMMARIZER_MODEL)
+            result = parse_response(response)
 
             summary = result.get("summary", "")
             tags = result.get("tags", [])
@@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
             traceback.print_exc()
             logger.error(f"Summarization failed: {e}")
 
-    tokens = chunker.approx_token_count(summary)
-    if tokens > target_tokens * 1.5:
-        logger.warning(f"Summary too long ({tokens} tokens), truncating")
-        summary = truncate(content, target_tokens)
+    summary_tokens = tokens.approx_token_count(summary)
+    if summary_tokens > target_tokens * 1.5:
+        logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
+        summary = llms.truncate(content, target_tokens)
 
     return summary, tags
```
src/memory/common/tokens.py (new file, 128 lines):

```python
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)


CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85


def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens


def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens
```
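A quick sketch of the estimators; the outputs are rough approximations of provider token accounting, not billing-accurate numbers, and the model strings are illustrative:

```python
from PIL import Image

from memory.common import tokens

image = Image.new("RGB", (1024, 1024))
# Anthropic path: base cost plus ~1.15 tokens per 1000 pixels.
print(tokens.estimate_image_tokens(image, "anthropic/claude-3-haiku-20240307"))
# OpenAI path: resize to 768px shortest side, then count 512x512 tiles.
print(tokens.estimate_total_tokens("Describe the image.", [image], "openai/gpt-4o"))
```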
```diff
@@ -1,5 +1,6 @@
 import pytest
-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN
 
 
 @pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
         (" ", []),  # Just spaces
         ("\n\t ", []),  # Whitespace characters
         ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
-    ]
+    ],
 )
 def test_yield_word_chunk_basic_behavior(text, expected):
     """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
     [
         (
             "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
-            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
         ),
         (
             "supercalifragilisticexpialidocious",
             ["supercalifragilisticexpialidocious"],
-        )
-    ]
+        ),
+    ],
 )
 def test_yield_word_chunk_long_text(text, expected):
     """Test chunking with long text that exceeds token limits"""
@@ -53,7 +54,12 @@ def test_yield_word_chunk_small_token_limit():
     text = "one two three four five"
     max_tokens = 1  # Very small to force chunking after each word
 
-    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+    assert list(yield_word_chunks(text, max_tokens)) == [
+        "one two",
+        "three",
+        "four",
+        "five",
+    ]
 
 
 @pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
         (
             "word1 word2",  # 11 chars with space
             3,  # 12 chars limit
-            ["word1 word2"]
+            ["word1 word2"],
         ),
         # Text just over token limit should split
         (
             "word1 word2 word3",  # 17 chars with spaces
             4,  # 16 chars limit
-            ["word1 word2 word3"]
+            ["word1 word2 word3"],
         ),
         # Each word exactly at token limit
         (
             "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
             1,  # 1 token limit (4 chars)
-            ["aaaa", "bbbb", "cccc"]
+            ["aaaa", "bbbb", "cccc"],
         ),
-    ]
+    ],
 )
 def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
     """Test different combinations of text and token limits"""
@@ -98,12 +104,12 @@ def test_yield_word_chunk_real_world_example():
 
     max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
     assert list(yield_word_chunks(text, max_tokens)) == [
-        'The yield_word_chunks function splits text',
-        'into chunks based on word boundaries. It',
-        'tries to maximize chunk size while staying',
-        'under the specified token limit. This',
-        'behavior is essential for processing large',
-        'documents efficiently.',
+        "The yield_word_chunks function splits text",
+        "into chunks based on word boundaries. It",
+        "tries to maximize chunk size while staying",
+        "under the specified token limit. This",
+        "behavior is essential for processing large",
+        "documents efficiently.",
     ]
 
 
@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
     "text, expected",
     [
         ("", []),  # Empty text should yield nothing
-        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
+        (
+            "Simple paragraph",
+            ["Simple paragraph"],
+        ),  # Single paragraph under token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_yield_spans_basic_behavior(text, expected):
     """Test basic behavior of yield_spans with various inputs"""
@@ -146,7 +155,13 @@ def test_yield_spans_words():
     max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
     long_sentence = "This sentence has several words and needs word-level chunking."
 
-    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+    assert list(yield_spans(long_sentence, max_tokens)) == [
+        "This sentence",
+        "has several",
+        "words and needs",
+        "word-level",
+        "chunking.",
+    ]
 
 
 def test_yield_spans_complex_document():
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
         "Some are short.",
         "This one is longer and might need word",
         "splitting depending on the limit.",
-        "Final short paragraph."
+        "Final short paragraph.",
     ]
 
 
@@ -183,7 +198,11 @@ def test_yield_spans_with_punctuation():
     """Test sentence splitting with various punctuation"""
     text = "First sentence! Second sentence? Third sentence."
 
-    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+    assert list(yield_spans(text, max_tokens=10)) == [
+        "First sentence!",
+        "Second sentence?",
+        "Third sentence.",
+    ]
 
 
 def test_yield_spans_edge_cases():
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
         ("", []),  # Empty text
         ("Short text", ["Short text"]),  # Text below token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_chunk_text_basic_behavior(text, expected):
     """Test basic behavior of chunk_text with various inputs"""
@@ -226,10 +245,8 @@ def test_chunk_text_long_text():
 
     max_tokens = 10  # 10 tokens = ~40 chars
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
-        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
-    ] + [
-        'This is sentence 49.'
-    ]
+        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+    ] + ["This is sentence 49."]
 
 
 def test_chunk_text_with_overlap():
@@ -237,7 +254,11 @@ def test_chunk_text_with_overlap():
     # Create text with distinct parts to test overlap
     text = "Part A. Part B. Part C. Part D. Part E."
 
-    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+        "Part A. Part B. Part C.",
+        "Part C. Part D. Part E.",
+        "Part E.",
+    ]
 
 
 def test_chunk_text_zero_overlap():
@@ -245,7 +266,11 @@ def test_chunk_text_zero_overlap():
     text = "Part A. Part B. Part C. Part D. Part E."
 
     # 2 tokens = ~8 chars
-    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+        "Part A. Part B.",
+        "Part C. Part D.",
+        "Part E.",
+    ]
 
 
 def test_chunk_text_clean_break():
@@ -253,7 +278,10 @@ def test_chunk_text_clean_break():
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."
 
     max_tokens = 5  # Enough for about 2 sentences
-    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+        "First sentence. Second sentence.",
+        "Third sentence. Fourth sentence.",
+    ]
 
 
 def test_chunk_text_very_long_sentences():
@@ -262,24 +290,10 @@ def test_chunk_text_very_long_sentences():
 
     max_tokens = 5  # Small limit to force splitting
     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        'This is a very long sentence with many many',
-        'words that will definitely exceed the',
-        'token limit we set for',
-        'this particular test',
-        'case and should be split into multiple',
-        'chunks by the function.',
+        "This is a very long sentence with many many",
+        "words that will definitely exceed the",
+        "token limit we set for",
+        "this particular test",
+        "case and should be split into multiple",
+        "chunks by the function.",
     ]
-
-
-@pytest.mark.parametrize(
-    "string, expected_count",
-    [
-        ("", 0),
-        ("a" * CHARS_PER_TOKEN, 1),
-        ("a" * (CHARS_PER_TOKEN * 2), 2),
-        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
-        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
-    ]
-)
-def test_approx_token_count(string, expected_count):
-    assert approx_token_count(string) == expected_count
```
tests/memory/common/test_tokens.py (new file, 16 lines):

```python
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count
```