mirror of https://github.com/mruwnik/memory.git (synced 2025-06-29 07:34:43 +02:00)

commit 8eb6374cac: second pass in search
parent: 06eec621c1
@@ -9,6 +9,7 @@ const SearchResults = ({ results, isLoading }: { results: any[], isLoading: bool
   if (isLoading) {
     return <Loading message="Searching..." />
   }
+  console.log("results",results)
   return (
     <div className="search-results">
       {results.length > 0 && (
src/memory/api/search/scorer.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import asyncio
from bs4 import BeautifulSoup
from PIL import Image

from memory.common.db.models.source_item import Chunk
from memory.common import llms, settings, tokens


SCORE_CHUNK_SYSTEM_PROMPT = """
You are a helpful assistant that scores how relevant a chunk of text and/or image is to a query.

You are given a query and a chunk of text and/or an image. The chunk should be relevant to the query, but often won't be. Score the chunk based on how relevant it is to the query and assign a score on a gradient between 0 and 1, which is the probability that the chunk is relevant to the query.
"""

SCORE_CHUNK_PROMPT = """
Here is the query:
<query>{query}</query>

Here is the chunk:
<chunk>
{chunk}
</chunk>

Please return your score as a number between 0 and 1 formatted as:
<score>your score</score>

Please always return a summary of any images provided.
"""


async def score_chunk(query: str, chunk: Chunk) -> Chunk:
    data = chunk.data
    chunk_text = "\n".join(text for text in data if isinstance(text, str))
    images = [image for image in data if isinstance(image, Image.Image)]
    prompt = SCORE_CHUNK_PROMPT.format(query=query, chunk=chunk_text)
    response = await asyncio.to_thread(
        llms.call,
        prompt,
        settings.RANKER_MODEL,
        images=images,
        system_prompt=SCORE_CHUNK_SYSTEM_PROMPT,
    )

    soup = BeautifulSoup(response, "html.parser")
    if not (score := soup.find("score")):
        chunk.relevance_score = 0.0
    else:
        try:
            chunk.relevance_score = float(score.text.strip())
        except ValueError:
            chunk.relevance_score = 0.0

    return chunk


async def rank_chunks(
    query: str, chunks: list[Chunk], min_score: float = 0
) -> list[Chunk]:
    calls = [score_chunk(query, chunk) for chunk in chunks]
    scored = await asyncio.gather(*calls)
    return sorted(
        [chunk for chunk in scored if chunk.relevance_score >= min_score],
        key=lambda x: x.relevance_score or 0,
        reverse=True,
    )
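For context, a minimal usage sketch of the new second-pass ranker (not part of the commit; the query string and candidate_chunks list are hypothetical, and an async context is assumed):

    import asyncio

    from memory.api.search import scorer

    async def rerank(candidate_chunks):
        # Score every chunk against the query in parallel, drop anything
        # under min_score, and return the rest sorted best-first.
        return await scorer.rank_chunks(
            "how does search scoring work?", candidate_chunks, min_score=0.3
        )

    # ranked = asyncio.run(rerank(candidate_chunks))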
@@ -12,6 +12,7 @@ from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, SourceItem
 from memory.common.collections import ALL_COLLECTIONS
 from memory.api.search.embeddings import search_chunks_embeddings
+from memory.api.search import scorer

 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
@@ -40,7 +41,14 @@ async def search_chunks(
     with make_session() as db:
         chunks = (
             db.query(Chunk)
-            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .options(
+                load_only(
+                    Chunk.id,  # type: ignore
+                    Chunk.source_id,  # type: ignore
+                    Chunk.content,  # type: ignore
+                    Chunk.file_paths,  # type: ignore
+                )
+            )
             .filter(Chunk.id.in_(all_ids))
             .all()
         )
@@ -91,4 +99,6 @@ async def search(
         filters,
         timeout,
     )
+    if settings.ENABLE_SEARCH_SCORING:
+        chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
     return await search_sources(chunks, previews)
@@ -2,7 +2,7 @@ import logging
 from typing import Iterable, Any
 import re

-from memory.common import settings
+from memory.common import settings, tokens

 logger = logging.getLogger(__name__)

@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
 EMBEDDING_MAX_TOKENS = settings.EMBEDDING_MAX_TOKENS
 DEFAULT_CHUNK_TOKENS = settings.DEFAULT_CHUNK_TOKENS
 OVERLAP_TOKENS = settings.OVERLAP_TOKENS
-CHARS_PER_TOKEN = 4


 Vector = list[float]
@@ -22,10 +21,6 @@ Embedding = tuple[str, Vector, dict[str, Any]]
 _SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


-def approx_token_count(s: str) -> int:
-    return len(s) // CHARS_PER_TOKEN
-
-
 def yield_word_chunks(
     text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS
 ) -> Iterable[str]:
@@ -36,7 +31,7 @@ def yield_word_chunks(
     current = ""
     for word in words:
         new_chunk = f"{current} {word}".strip()
-        if current and approx_token_count(new_chunk) > max_tokens:
+        if current and tokens.approx_token_count(new_chunk) > max_tokens:
             yield current
             current = word
         else:
@@ -65,7 +60,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
         if not paragraph.strip():
             continue

-        if approx_token_count(paragraph) <= max_tokens:
+        if tokens.approx_token_count(paragraph) <= max_tokens:
             yield paragraph
             continue

@@ -73,7 +68,7 @@ def yield_spans(text: str, max_tokens: int = DEFAULT_CHUNK_TOKENS) -> Iterable[s
         if not sentence.strip():
             continue

-        if approx_token_count(sentence) <= max_tokens:
+        if tokens.approx_token_count(sentence) <= max_tokens:
             yield sentence
             continue

@@ -99,16 +94,16 @@ def chunk_text(
     if not text:
         return

-    if approx_token_count(text) <= max_tokens:
+    if tokens.approx_token_count(text) <= max_tokens:
         yield text
         return

-    overlap_chars = overlap * CHARS_PER_TOKEN
+    overlap_chars = overlap * tokens.CHARS_PER_TOKEN
     current = ""

     for span in yield_spans(text, max_tokens):
         current = f"{current} {span}".strip()
-        if approx_token_count(current) < max_tokens:
+        if tokens.approx_token_count(current) < max_tokens:
             continue

         if overlap <= 0:
@@ -156,9 +156,11 @@ class Chunk(Base):
     collection_name = Column(Text)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     checked_at = Column(DateTime(timezone=True), server_default=func.now())

     vector: list[float] = []
     item_metadata: dict[str, Any] = {}
     images: list[Image.Image] = []
+    relevance_score: float = 0.0
+
     # One of file_path or content must be populated
     __table_args__ = (
src/memory/common/llms.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import logging
import base64
import io
from typing import Any
from PIL import Image

from memory.common import settings, tokens

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """
You are a helpful assistant that creates concise summaries and identifies key topics.
"""


def encode_image(image: Image.Image) -> str:
    """Encode PIL Image to base64 string."""
    buffer = io.BytesIO()
    # Convert to RGB if necessary (for RGBA, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_openai(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call OpenAI API for summarization."""
    import openai

    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
    try:
        user_content: Any = [{"type": "text", "text": prompt}]
        if images:
            for image in images:
                encoded_image = encode_image(image)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    }
                )

        response = client.chat.completions.create(
            model=model.split("/")[1],
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": user_content},
            ],
            temperature=0.3,
            max_tokens=2048,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise


def call_anthropic(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    """Call Anthropic API for summarization."""
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    try:
        # Prepare the message content
        content: Any = [{"type": "text", "text": prompt}]
        if images:
            # Add images if provided
            for image in images:
                encoded_image = encode_image(image)
                content.append(
                    {  # type: ignore
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": encoded_image,
                        },
                    }
                )

        response = client.messages.create(
            model=model.split("/")[1],
            messages=[{"role": "user", "content": content}],  # type: ignore
            system=system_prompt,
            temperature=0.3,
            max_tokens=2048,
        )
        return response.content[0].text
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        raise


def call(
    prompt: str,
    model: str,
    images: list[Image.Image] = [],
    system_prompt: str = SYSTEM_PROMPT,
) -> str:
    if model.startswith("anthropic"):
        return call_anthropic(prompt, model, images, system_prompt)
    return call_openai(prompt, model, images, system_prompt)


def truncate(content: str, target_tokens: int) -> str:
    target_chars = target_tokens * tokens.CHARS_PER_TOKEN
    if len(content) > target_chars:
        return content[:target_chars].rsplit(" ", 1)[0] + "..."
    return content
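A minimal sketch of how the new helper might be invoked (not part of the commit; the model id is the commit's default RANKER_MODEL, photo.jpg is a hypothetical file, and the matching API key is assumed to be configured in settings):

    from PIL import Image

    from memory.common import llms

    # The "provider/model" prefix picks the backend; images are optional.
    answer = llms.call(
        "Describe this image in one sentence.",
        "anthropic/claude-3-haiku-20240307",
        images=[Image.open("photo.jpg")],
    )
    print(answer)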
@@ -131,10 +131,13 @@ if anthropic_key_file := os.getenv("ANTHROPIC_API_KEY_FILE"):
 else:
     ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+RANKER_MODEL = os.getenv("RANKER_MODEL", "anthropic/claude-3-haiku-20240307")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", 200000))

 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))

@@ -4,7 +4,7 @@ from typing import Any

 from bs4 import BeautifulSoup

-from memory.common import settings, chunker
+from memory.common import settings, tokens, llms

 logger = logging.getLogger(__name__)

@@ -65,57 +65,6 @@ def parse_response(response: str) -> dict[str, Any]:
     return {"summary": summary, "tags": tags}


-def _call_openai(prompt: str) -> dict[str, Any]:
-    """Call OpenAI API for summarization."""
-    import openai
-
-    client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
-    try:
-        response = client.chat.completions.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a helpful assistant that creates concise summaries and identifies key topics.",
-                },
-                {"role": "user", "content": prompt},
-            ],
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.choices[0].message.content or "")
-    except Exception as e:
-        logger.error(f"OpenAI API error: {e}")
-        raise
-
-
-def _call_anthropic(prompt: str) -> dict[str, Any]:
-    """Call Anthropic API for summarization."""
-    import anthropic
-
-    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
-    try:
-        response = client.messages.create(
-            model=settings.SUMMARIZER_MODEL.split("/")[1],
-            messages=[{"role": "user", "content": prompt}],
-            system="You are a helpful assistant that creates concise summaries and identifies key topics. Always respond with valid XML.",
-            temperature=0.3,
-            max_tokens=2048,
-        )
-        return parse_response(response.content[0].text)
-    except Exception as e:
-        logger.error(f"Anthropic API error: {e}")
-        logger.error(response.content[0].text)
-        raise
-
-
-def truncate(content: str, target_tokens: int) -> str:
-    target_chars = target_tokens * chunker.CHARS_PER_TOKEN
-    if len(content) > target_chars:
-        return content[:target_chars].rsplit(" ", 1)[0] + "..."
-    return content
-
-
 def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list[str]]:
     """
     Summarize content to approximately target_tokens length and generate tags.
@@ -136,7 +85,7 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
     summary, tags = content, []

     # If content is already short enough, just extract tags
-    current_tokens = chunker.approx_token_count(content)
+    current_tokens = tokens.approx_token_count(content)
     if current_tokens <= target_tokens:
         logger.info(
             f"Content already under {target_tokens} tokens, extracting tags only"
|
|||||||
else:
|
else:
|
||||||
prompt = SUMMARY_PROMPT.format(
|
prompt = SUMMARY_PROMPT.format(
|
||||||
target_tokens=target_tokens,
|
target_tokens=target_tokens,
|
||||||
target_chars=target_tokens * chunker.CHARS_PER_TOKEN,
|
target_chars=target_tokens * tokens.CHARS_PER_TOKEN,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunker.approx_token_count(prompt) > MAX_TOKENS:
|
if tokens.approx_token_count(prompt) > MAX_TOKENS:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Prompt too long ({chunker.approx_token_count(prompt)} tokens), truncating"
|
f"Prompt too long ({tokens.approx_token_count(prompt)} tokens), truncating"
|
||||||
)
|
)
|
||||||
prompt = truncate(prompt, MAX_TOKENS - 20)
|
prompt = llms.truncate(prompt, MAX_TOKENS - 20)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if settings.SUMMARIZER_MODEL.startswith("anthropic"):
|
response = llms.call(prompt, settings.SUMMARIZER_MODEL)
|
||||||
result = _call_anthropic(prompt)
|
result = parse_response(response)
|
||||||
else:
|
|
||||||
result = _call_openai(prompt)
|
|
||||||
|
|
||||||
summary = result.get("summary", "")
|
summary = result.get("summary", "")
|
||||||
tags = result.get("tags", [])
|
tags = result.get("tags", [])
|
||||||
@ -167,9 +114,9 @@ def summarize(content: str, target_tokens: int | None = None) -> tuple[str, list
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(f"Summarization failed: {e}")
|
logger.error(f"Summarization failed: {e}")
|
||||||
|
|
||||||
tokens = chunker.approx_token_count(summary)
|
summary_tokens = tokens.approx_token_count(summary)
|
||||||
if tokens > target_tokens * 1.5:
|
if summary_tokens > target_tokens * 1.5:
|
||||||
logger.warning(f"Summary too long ({tokens} tokens), truncating")
|
logger.warning(f"Summary too long ({summary_tokens} tokens), truncating")
|
||||||
summary = truncate(content, target_tokens)
|
summary = llms.truncate(content, target_tokens)
|
||||||
|
|
||||||
return summary, tags
|
return summary, tags
|
||||||
|
src/memory/common/tokens.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import logging
from PIL import Image
import math

logger = logging.getLogger(__name__)


CHARS_PER_TOKEN = 4


def approx_token_count(s: str) -> int:
    return len(s) // CHARS_PER_TOKEN


def estimate_openai_image_tokens(image: Image.Image, detail: str = "high") -> int:
    """
    Estimate tokens for an image using OpenAI's counting method.

    Args:
        image: PIL Image
        detail: "high" or "low" detail level

    Returns:
        Estimated token count
    """
    if detail == "low":
        return 85

    # For high detail, OpenAI resizes the image to fit within 2048x2048
    # while maintaining aspect ratio, then counts 512x512 tiles
    width, height = image.size

    # Resize logic to fit within 2048x2048
    if width > 2048 or height > 2048:
        if width > height:
            height = int(height * 2048 / width)
            width = 2048
        else:
            width = int(width * 2048 / height)
            height = 2048

    # Further resize so shortest side is 768px
    if width < height:
        if width > 768:
            height = int(height * 768 / width)
            width = 768
    else:
        if height > 768:
            width = int(width * 768 / height)
            height = 768

    # Count 512x512 tiles
    tiles_width = math.ceil(width / 512)
    tiles_height = math.ceil(height / 512)
    total_tiles = tiles_width * tiles_height

    # Each tile costs 170 tokens, plus 85 base tokens
    return total_tiles * 170 + 85


def estimate_anthropic_image_tokens(image: Image.Image) -> int:
    """
    Estimate tokens for an image using Anthropic's counting method.

    Args:
        image: PIL Image

    Returns:
        Estimated token count
    """
    width, height = image.size

    # Anthropic's token counting is based on image dimensions
    # They use approximately 1.2 tokens per "tile" where tiles are roughly 1024x1024
    # But they also have a base cost per image

    # Rough approximation based on Anthropic's documentation
    # They count tokens based on the image size after potential resizing
    total_pixels = width * height

    # Anthropic typically charges around 1.15 tokens per 1000 pixels
    # with a minimum base cost
    base_tokens = 100  # Base cost for any image
    pixel_tokens = math.ceil(total_pixels / 1000 * 1.15)

    return base_tokens + pixel_tokens


def estimate_image_tokens(image: Image.Image, model: str, detail: str = "high") -> int:
    """
    Estimate tokens for an image based on the model provider.

    Args:
        image: PIL Image
        model: Model string (e.g., "openai/gpt-4-vision-preview", "anthropic/claude-3-sonnet")
        detail: Detail level for OpenAI models ("high" or "low")

    Returns:
        Estimated token count
    """
    if model.startswith("anthropic"):
        return estimate_anthropic_image_tokens(image)
    else:
        return estimate_openai_image_tokens(image, detail)


def estimate_total_tokens(
    prompt: str, images: list[Image.Image], model: str, detail: str = "high"
) -> int:
    """
    Estimate total tokens for a prompt with images.

    Args:
        prompt: Text prompt
        images: List of PIL Images
        model: Model string
        detail: Detail level for OpenAI models

    Returns:
        Estimated total token count
    """
    # Estimate text tokens
    text_tokens = approx_token_count(prompt)

    # Estimate image tokens
    image_tokens = sum(estimate_image_tokens(img, model, detail) for img in images)

    return text_tokens + image_tokens
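To make the OpenAI tile arithmetic above concrete, a small worked example (hypothetical 1500x1000 image): it already fits within 2048x2048, the shortest side is scaled down to 768, giving 1152x768, which spans ceil(1152/512) * ceil(768/512) = 3 * 2 = 6 tiles, so 6 * 170 + 85 = 1105 tokens:

    from PIL import Image

    from memory.common import tokens

    img = Image.new("RGB", (1500, 1000))  # only the dimensions matter here
    assert tokens.estimate_openai_image_tokens(img, detail="high") == 1105
    assert tokens.estimate_openai_image_tokens(img, detail="low") == 85  # flat low-detail cost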
@@ -1,5 +1,6 @@
 import pytest
-from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
+from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text
+from memory.common.tokens import CHARS_PER_TOKEN


 @pytest.mark.parametrize(
@@ -12,7 +13,7 @@ from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CH
         (" ", []),  # Just spaces
         ("\n\t ", []),  # Whitespace characters
         ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
-    ]
+    ],
 )
 def test_yield_word_chunk_basic_behavior(text, expected):
     """Test basic behavior of yield_word_chunks with various inputs"""
@@ -24,13 +25,13 @@ def test_yield_word_chunk_basic_behavior(text, expected):
     [
         (
             "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
-            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
+            ["word1 word2 word3 word4", "verylongwordthatexceedsthelimit word5"],
         ),
         (
             "supercalifragilisticexpialidocious",
             ["supercalifragilisticexpialidocious"],
-        )
-    ]
+        ),
+    ],
 )
 def test_yield_word_chunk_long_text(text, expected):
     """Test chunking with long text that exceeds token limits"""
@@ -53,7 +54,12 @@ def test_yield_word_chunk_small_token_limit():
     text = "one two three four five"
     max_tokens = 1  # Very small to force chunking after each word

-    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
+    assert list(yield_word_chunks(text, max_tokens)) == [
+        "one two",
+        "three",
+        "four",
+        "five",
+    ]


 @pytest.mark.parametrize(
@@ -67,21 +73,21 @@ def test_yield_word_chunk_small_token_limit():
         (
             "word1 word2",  # 11 chars with space
             3,  # 12 chars limit
-            ["word1 word2"]
+            ["word1 word2"],
         ),
         # Text just over token limit should split
         (
             "word1 word2 word3",  # 17 chars with spaces
             4,  # 16 chars limit
-            ["word1 word2 word3"]
+            ["word1 word2 word3"],
         ),
         # Each word exactly at token limit
         (
             "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
             1,  # 1 token limit (4 chars)
-            ["aaaa", "bbbb", "cccc"]
+            ["aaaa", "bbbb", "cccc"],
         ),
-    ]
+    ],
 )
 def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
     """Test different combinations of text and token limits"""
@@ -98,12 +104,12 @@ def test_yield_word_chunk_real_world_example():

     max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
     assert list(yield_word_chunks(text, max_tokens)) == [
-        'The yield_word_chunks function splits text',
-        'into chunks based on word boundaries. It',
-        'tries to maximize chunk size while staying',
-        'under the specified token limit. This',
-        'behavior is essential for processing large',
-        'documents efficiently.',
+        "The yield_word_chunks function splits text",
+        "into chunks based on word boundaries. It",
+        "tries to maximize chunk size while staying",
+        "under the specified token limit. This",
+        "behavior is essential for processing large",
+        "documents efficiently.",
     ]


@@ -112,9 +118,12 @@ def test_yield_word_chunk_real_world_example():
     "text, expected",
     [
         ("", []),  # Empty text should yield nothing
-        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
+        (
+            "Simple paragraph",
+            ["Simple paragraph"],
+        ),  # Single paragraph under token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_yield_spans_basic_behavior(text, expected):
     """Test basic behavior of yield_spans with various inputs"""
@@ -146,7 +155,13 @@ def test_yield_spans_words():
     max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
     long_sentence = "This sentence has several words and needs word-level chunking."

-    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
+    assert list(yield_spans(long_sentence, max_tokens)) == [
+        "This sentence",
+        "has several",
+        "words and needs",
+        "word-level",
+        "chunking.",
+    ]


 def test_yield_spans_complex_document():
@@ -167,7 +182,7 @@ def test_yield_spans_complex_document():
         "Some are short.",
         "This one is longer and might need word",
         "splitting depending on the limit.",
-        "Final short paragraph."
+        "Final short paragraph.",
     ]


@@ -183,7 +198,11 @@ def test_yield_spans_with_punctuation():
     """Test sentence splitting with various punctuation"""
     text = "First sentence! Second sentence? Third sentence."

-    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
+    assert list(yield_spans(text, max_tokens=10)) == [
+        "First sentence!",
+        "Second sentence?",
+        "Third sentence.",
+    ]


 def test_yield_spans_edge_cases():
@@ -199,7 +218,7 @@ def test_yield_spans_edge_cases():
         ("", []),  # Empty text
         ("Short text", ["Short text"]),  # Text below token limit
         (" ", []),  # Just whitespace
-    ]
+    ],
 )
 def test_chunk_text_basic_behavior(text, expected):
     """Test basic behavior of chunk_text with various inputs"""
@@ -226,10 +245,8 @@ def test_chunk_text_long_text():

     max_tokens = 10  # 10 tokens = ~40 chars
     assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
-        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
-    ] + [
-        'This is sentence 49.'
-    ]
+        f"This is sentence {i:02}. This is sentence {i + 1:02}." for i in range(49)
+    ] + ["This is sentence 49."]


 def test_chunk_text_with_overlap():
@@ -237,7 +254,11 @@ def test_chunk_text_with_overlap():
     # Create text with distinct parts to test overlap
     text = "Part A. Part B. Part C. Part D. Part E."

-    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=4, overlap=3)) == [
+        "Part A. Part B. Part C.",
+        "Part C. Part D. Part E.",
+        "Part E.",
+    ]


 def test_chunk_text_zero_overlap():
@@ -245,7 +266,11 @@ def test_chunk_text_zero_overlap():
     text = "Part A. Part B. Part C. Part D. Part E."

     # 2 tokens = ~8 chars
-    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
+    assert list(chunk_text(text, max_tokens=2, overlap=0)) == [
+        "Part A. Part B.",
+        "Part C. Part D.",
+        "Part E.",
+    ]


 def test_chunk_text_clean_break():
@@ -253,7 +278,10 @@ def test_chunk_text_clean_break():
     text = "First sentence. Second sentence. Third sentence. Fourth sentence."

     max_tokens = 5  # Enough for about 2 sentences
-    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
+    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == [
+        "First sentence. Second sentence.",
+        "Third sentence. Fourth sentence.",
+    ]


 def test_chunk_text_very_long_sentences():
@@ -262,24 +290,10 @@ def test_chunk_text_very_long_sentences():

     max_tokens = 5  # Small limit to force splitting
     assert list(chunk_text(text, max_tokens=max_tokens)) == [
-        'This is a very long sentence with many many',
-        'words that will definitely exceed the',
-        'token limit we set for',
-        'this particular test',
-        'case and should be split into multiple',
-        'chunks by the function.',
+        "This is a very long sentence with many many",
+        "words that will definitely exceed the",
+        "token limit we set for",
+        "this particular test",
+        "case and should be split into multiple",
+        "chunks by the function.",
     ]
-
-
-@pytest.mark.parametrize(
-    "string, expected_count",
-    [
-        ("", 0),
-        ("a" * CHARS_PER_TOKEN, 1),
-        ("a" * (CHARS_PER_TOKEN * 2), 2),
-        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
-        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
-    ]
-)
-def test_approx_token_count(string, expected_count):
-    assert approx_token_count(string) == expected_count
tests/memory/common/test_tokens.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import pytest
from memory.common.tokens import CHARS_PER_TOKEN, approx_token_count


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ],
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count