diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py index 5912547..33b9a13 100644 --- a/src/memory/api/search/search.py +++ b/src/memory/api/search/search.py @@ -28,6 +28,9 @@ if settings.ENABLE_RERANKING: from memory.api.search.types import SearchConfig, SearchFilters, SearchResult +# Default config for when none is provided +_DEFAULT_CONFIG = SearchConfig() + logger = logging.getLogger(__name__) # Reciprocal Rank Fusion constant (k parameter) @@ -313,6 +316,7 @@ async def search_chunks( limit: int = 10, filters: SearchFilters = {}, timeout: int = 2, + config: SearchConfig = _DEFAULT_CONFIG, ) -> list[Chunk]: """ Search chunks using embedding similarity and optionally BM25. @@ -322,7 +326,19 @@ async def search_chunks( If HyDE is enabled, also generates a hypothetical document from the query and includes it in the embedding search for better semantic matching. + + Enhancement flags in config override global settings when set: + - useBm25: Enable BM25 lexical search + - useHyde: Enable HyDE query expansion + - useReranking: Enable cross-encoder reranking + - useQueryExpansion: Enable synonym/abbreviation expansion """ + # Resolve enhancement flags: config overrides global settings + use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH + use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION + use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING + use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True + # Search for more candidates than requested, fuse scores, then return top N # This helps find results that rank well in one method but not the other internal_limit = limit * CANDIDATE_MULTIPLIER @@ -331,17 +347,20 @@ async def search_chunks( query_text = " ".join( c for chunk in data for c in chunk.data if isinstance(c, str) ) - expanded_query = expand_query(query_text) - # If 
query was expanded, use expanded version for search - if expanded_query != query_text: - logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'") - search_data = [extract.DataChunk(data=[expanded_query])] + if use_query_expansion: + expanded_query = expand_query(query_text) + # If query was expanded, use expanded version for search + if expanded_query != query_text: + logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'") + search_data = [extract.DataChunk(data=[expanded_query])] + else: + search_data = list(data) # Copy to avoid modifying original else: - search_data = list(data) # Copy to avoid modifying original + search_data = list(data) # Apply HyDE expansion if enabled - if settings.ENABLE_HYDE_EXPANSION: + if use_hyde: # Only expand queries with 4+ words (short queries are usually specific enough) if len(query_text.split()) >= 4: try: @@ -361,7 +380,7 @@ async def search_chunks( # Run BM25 search if enabled bm25_scores: dict[str, float] = {} - if settings.ENABLE_BM25_SEARCH: + if use_bm25: try: bm25_scores = await search_bm25_chunks( data, modalities, internal_limit, filters, timeout @@ -378,7 +397,7 @@ async def search_chunks( # Sort by score and take top results # If reranking is enabled, fetch more candidates for the reranker to work with sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True) - if settings.ENABLE_RERANKING: + if use_reranking: fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER else: fetch_limit = limit @@ -421,7 +440,7 @@ async def search_chunks( apply_source_boosts(chunks, set()) # Rerank using cross-encoder for better precision - if settings.ENABLE_RERANKING and chunks and query_text.strip(): + if use_reranking and chunks and query_text.strip(): try: chunks = await rerank_chunks( query_text, chunks, model=settings.RERANK_MODEL, top_k=limit @@ -451,7 +470,7 @@ async def search( data: list[extract.DataChunk], modalities: set[str] = set(), filters: SearchFilters = {}, - config: 
SearchConfig = SearchConfig(), + config: SearchConfig = _DEFAULT_CONFIG, ) -> list[SearchResult]: """ Search across knowledge base using text query and optional files. @@ -472,6 +491,7 @@ async def search( config.limit, filters, config.timeout, + config, ) if settings.ENABLE_SEARCH_SCORING and config.useScores: chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3) diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py index 00ee7ff..f3bc63a 100644 --- a/src/memory/api/search/types.py +++ b/src/memory/api/search/types.py @@ -83,6 +83,12 @@ class SearchConfig(BaseModel): previews: bool = False useScores: bool = False + # Optional enhancement flags (None = use global setting from env; useQueryExpansion has no env setting and defaults to enabled) + useBm25: Optional[bool] = None + useHyde: Optional[bool] = None + useReranking: Optional[bool] = None + useQueryExpansion: Optional[bool] = None + def model_post_init(self, __context) -> None: # Enforce reasonable limits if self.limit < 1: