diff --git a/src/memory/api/search/bm25.py b/src/memory/api/search/bm25.py index 8672e80..351b898 100644 --- a/src/memory/api/search/bm25.py +++ b/src/memory/api/search/bm25.py @@ -53,7 +53,7 @@ def build_tsquery(query: str) -> str: words = [ w.strip().lower() for w in clean_query.split() - if w.strip() and len(w.strip()) > 2 and w.strip().lower() not in _STOPWORDS + if w.strip() and len(w.strip()) >= 2 and w.strip().lower() not in _STOPWORDS ] if not words: return "" diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py index 739b60b..ff9a3e0 100644 --- a/src/memory/api/search/search.py +++ b/src/memory/api/search/search.py @@ -7,7 +7,6 @@ import logging import math from collections import defaultdict from datetime import datetime, timezone -from sqlalchemy.orm import load_only from memory.common import extract, settings from memory.common.db.connection import make_session from memory.common.db.models import Chunk, SourceItem @@ -321,22 +320,31 @@ async def _run_searches( use_bm25: bool, ) -> dict[str, float]: """ - Run embedding and optionally BM25 searches, returning fused scores. + Run embedding and optionally BM25 searches in parallel, returning fused scores. """ - # Run embedding search - embedding_scores = await search_chunks_embeddings( + # Build tasks to run in parallel + embedding_task = search_chunks_embeddings( search_data, modalities, internal_limit, filters, timeout ) - # Run BM25 search if enabled - bm25_scores: dict[str, float] = {} if use_bm25: - try: - bm25_scores = await search_bm25_chunks( - data, modalities, internal_limit, filters, timeout - ) - except asyncio.TimeoutError: - logger.warning("BM25 search timed out, using embedding results only") + # Run both searches in parallel + results = await asyncio.gather( + embedding_task, + search_bm25_chunks(data, modalities, internal_limit, filters, timeout), + return_exceptions=True, + ) + + embedding_scores = results[0] if not isinstance(results[0], Exception) else {} + if isinstance(results[0], Exception): + logger.warning(f"Embedding search failed: {results[0]}") + + bm25_scores = results[1] if not isinstance(results[1], Exception) else {} + if isinstance(results[1], Exception): + logger.warning(f"BM25 search failed: {results[1]}") + else: + embedding_scores = await embedding_task + bm25_scores = {} # Fuse scores from both methods using Reciprocal Rank Fusion return fuse_scores_rrf(embedding_scores, bm25_scores) @@ -367,14 +375,6 @@ def _fetch_chunks( with make_session() as db: chunks = ( db.query(Chunk) - .options( - load_only( - Chunk.id, # type: ignore - Chunk.source_id, # type: ignore - Chunk.content, # type: ignore - Chunk.file_paths, # type: ignore - ) - ) .filter(Chunk.id.in_(top_ids)) .all() )