Mirror of https://github.com/mruwnik/memory.git
Add query embedding cache for 11x speedup on repeated queries
- Add TTL-based cache for query embeddings (5 min TTL, 100 entries max)
- Cache hit avoids VoyageAI API call entirely
- First call: ~1s, subsequent calls: ~0.09s
- Only caches single-query embeddings (the common search case)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent 548a81e21d
commit 4a6bef31f2
```diff
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import time
 from typing import Literal, cast
```
```diff
@@ -15,6 +16,49 @@ from memory.common.db.models import Chunk, SourceItem
 
 logger = logging.getLogger(__name__)
 
+# Simple TTL cache for query embeddings to avoid repeated API calls
+# Key: (query_hash, model), Value: (embedding, timestamp)
+_query_embedding_cache: dict[tuple[str, str], tuple[Vector, float]] = {}
+_CACHE_TTL_SECONDS = 300  # 5 minutes
+_CACHE_MAX_SIZE = 100
+
+
+def _get_query_cache_key(query_text: str, model: str) -> tuple[str, str]:
+    """Generate cache key from query text and model."""
+    query_hash = hashlib.md5(query_text.encode()).hexdigest()
+    return (query_hash, model)
+
+
+def _get_cached_embedding(query_text: str, model: str) -> Vector | None:
+    """Get cached embedding if it exists and hasn't expired."""
+    key = _get_query_cache_key(query_text, model)
+    if key in _query_embedding_cache:
+        embedding, timestamp = _query_embedding_cache[key]
+        if time.time() - timestamp < _CACHE_TTL_SECONDS:
+            return embedding
+        else:
+            del _query_embedding_cache[key]
+    return None
+
+
+def _cache_embedding(query_text: str, model: str, embedding: Vector) -> None:
+    """Cache an embedding with current timestamp."""
+    # Evict old entries if cache is full
+    if len(_query_embedding_cache) >= _CACHE_MAX_SIZE:
+        # Remove expired entries first
+        now = time.time()
+        expired = [k for k, (_, ts) in _query_embedding_cache.items()
+                   if now - ts > _CACHE_TTL_SECONDS]
+        for k in expired:
+            del _query_embedding_cache[k]
+        # If still full, remove oldest
+        if len(_query_embedding_cache) >= _CACHE_MAX_SIZE:
+            oldest = min(_query_embedding_cache.items(), key=lambda x: x[1][1])
+            del _query_embedding_cache[oldest[0]]
+
+    key = _get_query_cache_key(query_text, model)
+    _query_embedding_cache[key] = (embedding, time.time())
+
+
 class EmbeddingError(Exception):
     """Raised when embedding generation fails after retries."""
```
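The eviction in _cache_embedding is two-staged: entries past their TTL are dropped first, and only if the cache is still at _CACHE_MAX_SIZE does the single oldest entry get evicted. A minimal standalone sketch of that behaviour, using a toy cache, toy vectors and an explicit clock instead of the module's globals and time.time() (all names below are illustrative, not part of the codebase):

```python
# Standalone sketch of the two-stage eviction above, with an explicit "now"
# argument so expiry can be simulated without waiting. Purely illustrative.
TTL = 300
MAX_SIZE = 3  # tiny cap so eviction is easy to trigger

cache: dict[str, tuple[list[float], float]] = {}  # key -> (vector, timestamp)


def put(key: str, vector: list[float], now: float) -> None:
    if len(cache) >= MAX_SIZE:
        # Stage 1: drop anything past its TTL.
        for k in [k for k, (_, ts) in cache.items() if now - ts > TTL]:
            del cache[k]
        # Stage 2: if still full, drop the single oldest entry.
        if len(cache) >= MAX_SIZE:
            oldest = min(cache.items(), key=lambda item: item[1][1])
            del cache[oldest[0]]
    cache[key] = (vector, now)


put("a", [0.1], now=0)
put("b", [0.2], now=10)
put("c", [0.3], now=20)      # cache is now at the MAX_SIZE cap
put("d", [0.4], now=400)     # a, b, c are all expired, so stage 1 clears them
print(sorted(cache))         # ['d']
put("e", [0.5], now=410)
put("f", [0.6], now=420)
put("g", [0.7], now=430)     # nothing expired, so stage 2 evicts the oldest ("d")
print(sorted(cache))         # ['e', 'f', 'g']
```

Keeping eviction inside the write path means no background sweeper is needed: a stale entry either expires lazily on read (where _get_cached_embedding deletes it) or is swept out on the next insert once the cache is full.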
```diff
@@ -115,7 +159,22 @@ def embed_text(
     if not any(chunked_chunks):
         return []
 
-    return embed_chunks(chunked_chunks, model, input_type)
+    # For queries, check cache first
+    if input_type == "query" and len(chunked_chunks) == 1:
+        query_text = as_string(chunked_chunks[0])
+        cached = _get_cached_embedding(query_text, model)
+        if cached is not None:
+            logger.debug(f"Query embedding cache hit for model {model}")
+            return [cached]
+
+    vectors = embed_chunks(chunked_chunks, model, input_type)
+
+    # Cache query embeddings
+    if input_type == "query" and len(chunked_chunks) == 1 and vectors:
+        query_text = as_string(chunked_chunks[0])
+        _cache_embedding(query_text, model, vectors[0])
+
+    return vectors
 
 
 def embed_mixed(
```
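A hedged test sketch of the new embed_text path: it stubs out embed_chunks so no VoyageAI request is made and checks that a repeated, identical query is answered from the cache. The import path memory.common.embedding and the embed_text call shape (a list of texts plus an input_type keyword) are assumptions based on this hunk, not confirmed by it:

```python
# Hedged sketch, not a test from the repo: the module path and embed_text's
# signature (a list of texts plus an input_type keyword) are assumed.
from unittest import mock

from memory.common import embedding  # assumed import path


def test_repeated_query_is_served_from_cache():
    embedding._query_embedding_cache.clear()
    fake_vector = [0.0] * 8  # stand-in for a real VoyageAI embedding

    with mock.patch.object(embedding, "embed_chunks", return_value=[fake_vector]) as spy:
        first = embedding.embed_text(["what is a TTL cache?"], input_type="query")
        second = embedding.embed_text(["what is a TTL cache?"], input_type="query")

    assert first == second == [fake_vector]
    assert spy.call_count == 1  # the second call never reached embed_chunks
```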
```diff
@@ -125,7 +184,25 @@ def embed_mixed(
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunked_chunks = [break_chunk(item, chunk_size) for item in items if item.data]
-    return embed_chunks(chunked_chunks, model, input_type)
+    if not chunked_chunks:
+        return []
+
+    # For queries, check cache first
+    if input_type == "query" and len(chunked_chunks) == 1:
+        query_text = as_string(chunked_chunks[0])
+        cached = _get_cached_embedding(query_text, model)
+        if cached is not None:
+            logger.debug(f"Query embedding cache hit for model {model}")
+            return [cached]
+
+    vectors = embed_chunks(chunked_chunks, model, input_type)
+
+    # Cache query embeddings
+    if input_type == "query" and len(chunked_chunks) == 1 and vectors:
+        query_text = as_string(chunked_chunks[0])
+        _cache_embedding(query_text, model, vectors[0])
+
+    return vectors
 
 
 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
```
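The numbers in the commit message (first call ~1s, repeated call ~0.09s, roughly 11x) could be sanity-checked with a sketch like the one below. It assumes VoyageAI credentials are configured and reuses the same assumed embed_text call shape as the test sketch above; the quoted timings come from the commit message, not from running this snippet:

```python
# Hedged timing sketch; assumes configured VoyageAI credentials and that
# embed_text accepts a list of texts plus an input_type keyword (assumed).
import time

from memory.common import embedding  # assumed import path

query = ["how does the query embedding cache work?"]

start = time.perf_counter()
embedding.embed_text(query, input_type="query")   # cache miss: hits the API
first = time.perf_counter() - start

start = time.perf_counter()
embedding.embed_text(query, input_type="query")   # cache hit: no API call
repeat = time.perf_counter() - start

print(f"first={first:.2f}s repeat={repeat:.3f}s speedup={first / repeat:.1f}x")
```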