Add per-request configuration for search enhancements

Allow callers to enable/disable BM25, HyDE, reranking, and query
expansion on a per-request basis via SearchConfig. When not specified,
falls back to global settings from environment variables.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
mruwnik 2025-12-21 12:38:01 +00:00
parent d9fcfe3878
commit a10f93cb3c
2 changed files with 37 additions and 11 deletions

View File

@ -28,6 +28,9 @@ if settings.ENABLE_RERANKING:
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
# Default config for when none is provided
_DEFAULT_CONFIG = SearchConfig()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Reciprocal Rank Fusion constant (k parameter) # Reciprocal Rank Fusion constant (k parameter)
@ -313,6 +316,7 @@ async def search_chunks(
limit: int = 10, limit: int = 10,
filters: SearchFilters = {}, filters: SearchFilters = {},
timeout: int = 2, timeout: int = 2,
config: SearchConfig = _DEFAULT_CONFIG,
) -> list[Chunk]: ) -> list[Chunk]:
""" """
Search chunks using embedding similarity and optionally BM25. Search chunks using embedding similarity and optionally BM25.
@ -322,7 +326,19 @@ async def search_chunks(
If HyDE is enabled, also generates a hypothetical document from the query If HyDE is enabled, also generates a hypothetical document from the query
and includes it in the embedding search for better semantic matching. and includes it in the embedding search for better semantic matching.
Enhancement flags in config override global settings when set:
- useBm25: Enable BM25 lexical search
- useHyde: Enable HyDE query expansion
- useReranking: Enable cross-encoder reranking
- useQueryExpansion: Enable synonym/abbreviation expansion
""" """
# Resolve enhancement flags: config overrides global settings
use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING
use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True
# Search for more candidates than requested, fuse scores, then return top N # Search for more candidates than requested, fuse scores, then return top N
# This helps find results that rank well in one method but not the other # This helps find results that rank well in one method but not the other
internal_limit = limit * CANDIDATE_MULTIPLIER internal_limit = limit * CANDIDATE_MULTIPLIER
@ -331,17 +347,20 @@ async def search_chunks(
query_text = " ".join( query_text = " ".join(
c for chunk in data for c in chunk.data if isinstance(c, str) c for chunk in data for c in chunk.data if isinstance(c, str)
) )
expanded_query = expand_query(query_text)
if use_query_expansion:
expanded_query = expand_query(query_text)
# If query was expanded, use expanded version for search # If query was expanded, use expanded version for search
if expanded_query != query_text: if expanded_query != query_text:
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'") logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
search_data = [extract.DataChunk(data=[expanded_query])] search_data = [extract.DataChunk(data=[expanded_query])]
else: else:
search_data = list(data) # Copy to avoid modifying original search_data = list(data) # Copy to avoid modifying original
else:
search_data = list(data)
# Apply HyDE expansion if enabled # Apply HyDE expansion if enabled
if settings.ENABLE_HYDE_EXPANSION: if use_hyde:
# Only expand queries with 4+ words (short queries are usually specific enough) # Only expand queries with 4+ words (short queries are usually specific enough)
if len(query_text.split()) >= 4: if len(query_text.split()) >= 4:
try: try:
@ -361,7 +380,7 @@ async def search_chunks(
# Run BM25 search if enabled # Run BM25 search if enabled
bm25_scores: dict[str, float] = {} bm25_scores: dict[str, float] = {}
if settings.ENABLE_BM25_SEARCH: if use_bm25:
try: try:
bm25_scores = await search_bm25_chunks( bm25_scores = await search_bm25_chunks(
data, modalities, internal_limit, filters, timeout data, modalities, internal_limit, filters, timeout
@ -378,7 +397,7 @@ async def search_chunks(
# Sort by score and take top results # Sort by score and take top results
# If reranking is enabled, fetch more candidates for the reranker to work with # If reranking is enabled, fetch more candidates for the reranker to work with
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True) sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
if settings.ENABLE_RERANKING: if use_reranking:
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
else: else:
fetch_limit = limit fetch_limit = limit
@ -421,7 +440,7 @@ async def search_chunks(
apply_source_boosts(chunks, set()) apply_source_boosts(chunks, set())
# Rerank using cross-encoder for better precision # Rerank using cross-encoder for better precision
if settings.ENABLE_RERANKING and chunks and query_text.strip(): if use_reranking and chunks and query_text.strip():
try: try:
chunks = await rerank_chunks( chunks = await rerank_chunks(
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
@ -451,7 +470,7 @@ async def search(
data: list[extract.DataChunk], data: list[extract.DataChunk],
modalities: set[str] = set(), modalities: set[str] = set(),
filters: SearchFilters = {}, filters: SearchFilters = {},
config: SearchConfig = SearchConfig(), config: SearchConfig = _DEFAULT_CONFIG,
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
Search across knowledge base using text query and optional files. Search across knowledge base using text query and optional files.
@ -472,6 +491,7 @@ async def search(
config.limit, config.limit,
filters, filters,
config.timeout, config.timeout,
config,
) )
if settings.ENABLE_SEARCH_SCORING and config.useScores: if settings.ENABLE_SEARCH_SCORING and config.useScores:
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3) chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)

View File

@ -83,6 +83,12 @@ class SearchConfig(BaseModel):
previews: bool = False previews: bool = False
useScores: bool = False useScores: bool = False
# Optional enhancement flags (None = use global setting from env)
useBm25: Optional[bool] = None
useHyde: Optional[bool] = None
useReranking: Optional[bool] = None
useQueryExpansion: Optional[bool] = None
def model_post_init(self, __context) -> None: def model_post_init(self, __context) -> None:
# Enforce reasonable limits # Enforce reasonable limits
if self.limit < 1: if self.limit < 1: