diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py
index 2f99e0e..892e336 100644
--- a/src/memory/common/embedding.py
+++ b/src/memory/common/embedding.py
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import time
 from typing import Literal, cast
@@ -15,6 +16,49 @@ from memory.common.db.models import Chunk, SourceItem
 
 logger = logging.getLogger(__name__)
 
+# Simple TTL cache for query embeddings to avoid repeated API calls
+# Key: (query_hash, model), Value: (embedding, timestamp)
+_query_embedding_cache: dict[tuple[str, str], tuple[Vector, float]] = {}
+_CACHE_TTL_SECONDS = 300  # 5 minutes
+_CACHE_MAX_SIZE = 100
+
+
+def _get_query_cache_key(query_text: str, model: str) -> tuple[str, str]:
+    """Generate cache key from query text and model."""
+    query_hash = hashlib.md5(query_text.encode()).hexdigest()
+    return (query_hash, model)
+
+
+def _get_cached_embedding(query_text: str, model: str) -> Vector | None:
+    """Get cached embedding if it exists and hasn't expired."""
+    key = _get_query_cache_key(query_text, model)
+    if key in _query_embedding_cache:
+        embedding, timestamp = _query_embedding_cache[key]
+        if time.time() - timestamp < _CACHE_TTL_SECONDS:
+            return embedding
+        else:
+            del _query_embedding_cache[key]
+    return None
+
+
+def _cache_embedding(query_text: str, model: str, embedding: Vector) -> None:
+    """Cache an embedding with current timestamp."""
+    # Evict old entries if cache is full
+    if len(_query_embedding_cache) >= _CACHE_MAX_SIZE:
+        # Remove expired entries first
+        now = time.time()
+        expired = [k for k, (_, ts) in _query_embedding_cache.items()
+                   if now - ts > _CACHE_TTL_SECONDS]
+        for k in expired:
+            del _query_embedding_cache[k]
+        # If still full, remove the oldest entry
+        if len(_query_embedding_cache) >= _CACHE_MAX_SIZE:
+            oldest = min(_query_embedding_cache.items(), key=lambda x: x[1][1])
+            del _query_embedding_cache[oldest[0]]
+
+    key = _get_query_cache_key(query_text, model)
+    _query_embedding_cache[key] = (embedding, time.time())
+
 
 class EmbeddingError(Exception):
     """Raised when embedding generation fails after retries."""
@@ -115,7 +159,22 @@ def embed_text(
     if not any(chunked_chunks):
         return []
 
-    return embed_chunks(chunked_chunks, model, input_type)
+    # For queries, check cache first
+    if input_type == "query" and len(chunked_chunks) == 1:
+        query_text = as_string(chunked_chunks[0])
+        cached = _get_cached_embedding(query_text, model)
+        if cached is not None:
+            logger.debug(f"Query embedding cache hit for model {model}")
+            return [cached]
+
+    vectors = embed_chunks(chunked_chunks, model, input_type)
+
+    # Cache query embeddings
+    if input_type == "query" and len(chunked_chunks) == 1 and vectors:
+        query_text = as_string(chunked_chunks[0])
+        _cache_embedding(query_text, model, vectors[0])
+
+    return vectors
 
 
 def embed_mixed(
@@ -125,7 +184,25 @@ def embed_mixed(
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
     chunked_chunks = [break_chunk(item, chunk_size) for item in items if item.data]
-    return embed_chunks(chunked_chunks, model, input_type)
+    if not chunked_chunks:
+        return []
+
+    # For queries, check cache first
+    if input_type == "query" and len(chunked_chunks) == 1:
+        query_text = as_string(chunked_chunks[0])
+        cached = _get_cached_embedding(query_text, model)
+        if cached is not None:
+            logger.debug(f"Query embedding cache hit for model {model}")
+            return [cached]
+
+    vectors = embed_chunks(chunked_chunks, model, input_type)
+
+    # Cache query embeddings
+    if input_type == "query" and len(chunked_chunks) == 1 and vectors:
+        query_text = as_string(chunked_chunks[0])
+        _cache_embedding(query_text, model, vectors[0])
+
+    return vectors
 
 
 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
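
Review note: a minimal sketch of how the new cache helpers could be exercised in a test, assuming the file is importable as memory.common.embedding. The test name, query string, model names, and placeholder vector below are made up for illustration; embed_text/embed_mixed themselves are not called, only the private helpers whose signatures appear in the diff.

import time
from unittest import mock

from memory.common import embedding  # assumed import path for src/memory/common/embedding.py


def test_query_embedding_cache_hit_and_expiry():
    embedding._query_embedding_cache.clear()

    vector = [0.1, 0.2, 0.3]  # placeholder for a real embedding Vector
    embedding._cache_embedding("what is qdrant?", "example-model", vector)

    # Same query text and model within the TTL -> served from the cache
    assert embedding._get_cached_embedding("what is qdrant?", "example-model") == vector

    # A different model produces a different cache key -> miss
    assert embedding._get_cached_embedding("what is qdrant?", "other-model") is None

    # Once the TTL has passed, the lookup deletes the stale entry and reports a miss
    later = time.time() + embedding._CACHE_TTL_SECONDS + 1
    with mock.patch("time.time", return_value=later):
        assert embedding._get_cached_embedding("what is qdrant?", "example-model") is None
    key = embedding._get_query_cache_key("what is qdrant?", "example-model")
    assert key not in embedding._query_embedding_cache

Because the key is an MD5 of the query text paired with the model name, the same query against two models is cached independently, which the second assertion checks.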