mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 17:22:58 +01:00
Add per-request configuration for search enhancements
Allow callers to enable/disable BM25, HyDE, reranking, and query expansion on a per-request basis via SearchConfig. When a flag is not specified, it falls back to the global setting from environment variables (query expansion has no global setting and defaults to enabled). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
d9fcfe3878
commit
a10f93cb3c
@ -28,6 +28,9 @@ if settings.ENABLE_RERANKING:
|
|||||||
|
|
||||||
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
||||||
|
|
||||||
|
# Default config for when none is provided
|
||||||
|
_DEFAULT_CONFIG = SearchConfig()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Reciprocal Rank Fusion constant (k parameter)
|
# Reciprocal Rank Fusion constant (k parameter)
|
||||||
@ -313,6 +316,7 @@ async def search_chunks(
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
filters: SearchFilters = {},
|
filters: SearchFilters = {},
|
||||||
timeout: int = 2,
|
timeout: int = 2,
|
||||||
|
config: SearchConfig = _DEFAULT_CONFIG,
|
||||||
) -> list[Chunk]:
|
) -> list[Chunk]:
|
||||||
"""
|
"""
|
||||||
Search chunks using embedding similarity and optionally BM25.
|
Search chunks using embedding similarity and optionally BM25.
|
||||||
@ -322,7 +326,19 @@ async def search_chunks(
|
|||||||
|
|
||||||
If HyDE is enabled, also generates a hypothetical document from the query
|
If HyDE is enabled, also generates a hypothetical document from the query
|
||||||
and includes it in the embedding search for better semantic matching.
|
and includes it in the embedding search for better semantic matching.
|
||||||
|
|
||||||
|
Enhancement flags in config override global settings when set:
|
||||||
|
- useBm25: Enable BM25 lexical search
|
||||||
|
- useHyde: Enable HyDE (hypothetical document generated from the query) for embedding search
|
||||||
|
- useReranking: Enable cross-encoder reranking
|
||||||
|
- useQueryExpansion: Enable synonym/abbreviation expansion (no global setting; enabled when unset)
|
||||||
"""
|
"""
|
||||||
|
# Resolve enhancement flags: config overrides global settings (query expansion has no global setting and defaults to on)
|
||||||
|
use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
||||||
|
use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
||||||
|
use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING
|
||||||
|
use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True
|
||||||
|
|
||||||
# Search for more candidates than requested, fuse scores, then return top N
|
# Search for more candidates than requested, fuse scores, then return top N
|
||||||
# This helps find results that rank well in one method but not the other
|
# This helps find results that rank well in one method but not the other
|
||||||
internal_limit = limit * CANDIDATE_MULTIPLIER
|
internal_limit = limit * CANDIDATE_MULTIPLIER
|
||||||
@ -331,17 +347,20 @@ async def search_chunks(
|
|||||||
query_text = " ".join(
|
query_text = " ".join(
|
||||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
c for chunk in data for c in chunk.data if isinstance(c, str)
|
||||||
)
|
)
|
||||||
expanded_query = expand_query(query_text)
|
|
||||||
|
|
||||||
# If query was expanded, use expanded version for search
|
if use_query_expansion:
|
||||||
if expanded_query != query_text:
|
expanded_query = expand_query(query_text)
|
||||||
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
# If query was expanded, use expanded version for search
|
||||||
search_data = [extract.DataChunk(data=[expanded_query])]
|
if expanded_query != query_text:
|
||||||
|
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
||||||
|
search_data = [extract.DataChunk(data=[expanded_query])]
|
||||||
|
else:
|
||||||
|
search_data = list(data) # Copy to avoid modifying original
|
||||||
else:
|
else:
|
||||||
search_data = list(data) # Copy to avoid modifying original
|
search_data = list(data)
|
||||||
|
|
||||||
# Apply HyDE expansion if enabled
|
# Apply HyDE expansion if enabled
|
||||||
if settings.ENABLE_HYDE_EXPANSION:
|
if use_hyde:
|
||||||
# Only expand queries with 4+ words (short queries are usually specific enough)
|
# Only expand queries with 4+ words (short queries are usually specific enough)
|
||||||
if len(query_text.split()) >= 4:
|
if len(query_text.split()) >= 4:
|
||||||
try:
|
try:
|
||||||
@ -361,7 +380,7 @@ async def search_chunks(
|
|||||||
|
|
||||||
# Run BM25 search if enabled
|
# Run BM25 search if enabled
|
||||||
bm25_scores: dict[str, float] = {}
|
bm25_scores: dict[str, float] = {}
|
||||||
if settings.ENABLE_BM25_SEARCH:
|
if use_bm25:
|
||||||
try:
|
try:
|
||||||
bm25_scores = await search_bm25_chunks(
|
bm25_scores = await search_bm25_chunks(
|
||||||
data, modalities, internal_limit, filters, timeout
|
data, modalities, internal_limit, filters, timeout
|
||||||
@ -378,7 +397,7 @@ async def search_chunks(
|
|||||||
# Sort by score and take top results
|
# Sort by score and take top results
|
||||||
# If reranking is enabled, fetch more candidates for the reranker to work with
|
# If reranking is enabled, fetch more candidates for the reranker to work with
|
||||||
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
|
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
|
||||||
if settings.ENABLE_RERANKING:
|
if use_reranking:
|
||||||
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
||||||
else:
|
else:
|
||||||
fetch_limit = limit
|
fetch_limit = limit
|
||||||
@ -421,7 +440,7 @@ async def search_chunks(
|
|||||||
apply_source_boosts(chunks, set())
|
apply_source_boosts(chunks, set())
|
||||||
|
|
||||||
# Rerank using cross-encoder for better precision
|
# Rerank using cross-encoder for better precision
|
||||||
if settings.ENABLE_RERANKING and chunks and query_text.strip():
|
if use_reranking and chunks and query_text.strip():
|
||||||
try:
|
try:
|
||||||
chunks = await rerank_chunks(
|
chunks = await rerank_chunks(
|
||||||
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
||||||
@ -451,7 +470,7 @@ async def search(
|
|||||||
data: list[extract.DataChunk],
|
data: list[extract.DataChunk],
|
||||||
modalities: set[str] = set(),
|
modalities: set[str] = set(),
|
||||||
filters: SearchFilters = {},
|
filters: SearchFilters = {},
|
||||||
config: SearchConfig = SearchConfig(),
|
config: SearchConfig = _DEFAULT_CONFIG,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Search across knowledge base using text query and optional files.
|
Search across knowledge base using text query and optional files.
|
||||||
@ -472,6 +491,7 @@ async def search(
|
|||||||
config.limit,
|
config.limit,
|
||||||
filters,
|
filters,
|
||||||
config.timeout,
|
config.timeout,
|
||||||
|
config,
|
||||||
)
|
)
|
||||||
if settings.ENABLE_SEARCH_SCORING and config.useScores:
|
if settings.ENABLE_SEARCH_SCORING and config.useScores:
|
||||||
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
|
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
|
||||||
|
|||||||
@ -83,6 +83,12 @@ class SearchConfig(BaseModel):
|
|||||||
previews: bool = False
|
previews: bool = False
|
||||||
useScores: bool = False
|
useScores: bool = False
|
||||||
|
|
||||||
|
# Optional enhancement flags (None = use global setting from env; useQueryExpansion has no global setting, so None = enabled)
|
||||||
|
useBm25: Optional[bool] = None
|
||||||
|
useHyde: Optional[bool] = None
|
||||||
|
useReranking: Optional[bool] = None
|
||||||
|
useQueryExpansion: Optional[bool] = None
|
||||||
|
|
||||||
def model_post_init(self, __context) -> None:
|
def model_post_init(self, __context) -> None:
|
||||||
# Enforce reasonable limits
|
# Enforce reasonable limits
|
||||||
if self.limit < 1:
|
if self.limit < 1:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user