mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Add per-request configuration for search enhancements
Allow callers to enable/disable BM25, HyDE, reranking, and query expansion on a per-request basis via SearchConfig. When not specified, falls back to global settings from environment variables. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
d9fcfe3878
commit
a10f93cb3c
@ -28,6 +28,9 @@ if settings.ENABLE_RERANKING:
|
||||
|
||||
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
||||
|
||||
# Default config for when none is provided
|
||||
_DEFAULT_CONFIG = SearchConfig()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reciprocal Rank Fusion constant (k parameter)
|
||||
@ -313,6 +316,7 @@ async def search_chunks(
|
||||
limit: int = 10,
|
||||
filters: SearchFilters = {},
|
||||
timeout: int = 2,
|
||||
config: SearchConfig = _DEFAULT_CONFIG,
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Search chunks using embedding similarity and optionally BM25.
|
||||
@ -322,7 +326,19 @@ async def search_chunks(
|
||||
|
||||
If HyDE is enabled, also generates a hypothetical document from the query
|
||||
and includes it in the embedding search for better semantic matching.
|
||||
|
||||
Enhancement flags in config override global settings when set:
|
||||
- useBm25: Enable BM25 lexical search
|
||||
- useHyde: Enable HyDE query expansion
|
||||
- useReranking: Enable cross-encoder reranking
|
||||
- useQueryExpansion: Enable synonym/abbreviation expansion
|
||||
"""
|
||||
# Resolve enhancement flags: config overrides global settings
|
||||
use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
||||
use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
||||
use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING
|
||||
use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True
|
||||
|
||||
# Search for more candidates than requested, fuse scores, then return top N
|
||||
# This helps find results that rank well in one method but not the other
|
||||
internal_limit = limit * CANDIDATE_MULTIPLIER
|
||||
@ -331,17 +347,20 @@ async def search_chunks(
|
||||
query_text = " ".join(
|
||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
||||
)
|
||||
expanded_query = expand_query(query_text)
|
||||
|
||||
# If query was expanded, use expanded version for search
|
||||
if expanded_query != query_text:
|
||||
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
||||
search_data = [extract.DataChunk(data=[expanded_query])]
|
||||
if use_query_expansion:
|
||||
expanded_query = expand_query(query_text)
|
||||
# If query was expanded, use expanded version for search
|
||||
if expanded_query != query_text:
|
||||
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
||||
search_data = [extract.DataChunk(data=[expanded_query])]
|
||||
else:
|
||||
search_data = list(data) # Copy to avoid modifying original
|
||||
else:
|
||||
search_data = list(data) # Copy to avoid modifying original
|
||||
search_data = list(data)
|
||||
|
||||
# Apply HyDE expansion if enabled
|
||||
if settings.ENABLE_HYDE_EXPANSION:
|
||||
if use_hyde:
|
||||
# Only expand queries with 4+ words (short queries are usually specific enough)
|
||||
if len(query_text.split()) >= 4:
|
||||
try:
|
||||
@ -361,7 +380,7 @@ async def search_chunks(
|
||||
|
||||
# Run BM25 search if enabled
|
||||
bm25_scores: dict[str, float] = {}
|
||||
if settings.ENABLE_BM25_SEARCH:
|
||||
if use_bm25:
|
||||
try:
|
||||
bm25_scores = await search_bm25_chunks(
|
||||
data, modalities, internal_limit, filters, timeout
|
||||
@ -378,7 +397,7 @@ async def search_chunks(
|
||||
# Sort by score and take top results
|
||||
# If reranking is enabled, fetch more candidates for the reranker to work with
|
||||
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
|
||||
if settings.ENABLE_RERANKING:
|
||||
if use_reranking:
|
||||
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
||||
else:
|
||||
fetch_limit = limit
|
||||
@ -421,7 +440,7 @@ async def search_chunks(
|
||||
apply_source_boosts(chunks, set())
|
||||
|
||||
# Rerank using cross-encoder for better precision
|
||||
if settings.ENABLE_RERANKING and chunks and query_text.strip():
|
||||
if use_reranking and chunks and query_text.strip():
|
||||
try:
|
||||
chunks = await rerank_chunks(
|
||||
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
||||
@ -451,7 +470,7 @@ async def search(
|
||||
data: list[extract.DataChunk],
|
||||
modalities: set[str] = set(),
|
||||
filters: SearchFilters = {},
|
||||
config: SearchConfig = SearchConfig(),
|
||||
config: SearchConfig = _DEFAULT_CONFIG,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search across knowledge base using text query and optional files.
|
||||
@ -472,6 +491,7 @@ async def search(
|
||||
config.limit,
|
||||
filters,
|
||||
config.timeout,
|
||||
config,
|
||||
)
|
||||
if settings.ENABLE_SEARCH_SCORING and config.useScores:
|
||||
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
|
||||
|
||||
@ -83,6 +83,12 @@ class SearchConfig(BaseModel):
|
||||
previews: bool = False
|
||||
useScores: bool = False
|
||||
|
||||
# Optional enhancement flags (None = use global setting from env)
|
||||
useBm25: Optional[bool] = None
|
||||
useHyde: Optional[bool] = None
|
||||
useReranking: Optional[bool] = None
|
||||
useQueryExpansion: Optional[bool] = None
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# Enforce reasonable limits
|
||||
if self.limit < 1:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user