Add comprehensive tests for search improvements

- Add tests for extract_query_terms (stopword filtering, short words)
- Add tests for apply_query_term_boost (boost calculations, edge cases)
- Add tests for deduplicate_by_source (keeps highest per source)
- Add tests for apply_title_boost (title matching with mocked DB)
- Add tests for fuse_scores_rrf (RRF score fusion, ranking behavior)
- Add tests for rerank module (VoyageAI reranker mocking)

Uses pytest.mark.parametrize for concise, data-driven tests.
77 tests total covering all new search functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
mruwnik 2025-12-20 22:26:16 +00:00
parent c6cb793cdf
commit 09215adf9a
3 changed files with 760 additions and 0 deletions


@@ -0,0 +1,94 @@
"""
Cross-encoder reranking using VoyageAI's reranker.
Reranking improves search precision by using a cross-encoder model that
sees query and document together, rather than comparing embeddings separately.
"""
import asyncio
import logging
from typing import Optional
import voyageai
from memory.common import settings
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
# VoyageAI reranker models
# rerank-2: More accurate, slower
# rerank-2-lite: Faster, slightly less accurate
DEFAULT_RERANK_MODEL = "rerank-2-lite"
async def rerank_chunks(
query: str,
chunks: list[Chunk],
model: str = DEFAULT_RERANK_MODEL,
top_k: Optional[int] = None,
) -> list[Chunk]:
"""
Rerank chunks using VoyageAI's cross-encoder reranker.
Cross-encoders are more accurate than bi-encoders (embeddings) because
they see query and document together, allowing for deeper semantic matching.
Args:
query: The search query
chunks: List of candidate chunks to rerank
model: VoyageAI reranker model to use
top_k: If set, only return top k results
Returns:
Chunks sorted by reranker relevance score
"""
if not chunks:
return []
if not query.strip():
return chunks
# Extract text content from chunks
documents = []
chunk_map = {} # Map index to chunk
for i, chunk in enumerate(chunks):
content = chunk.content or ""
if not content and hasattr(chunk, "data"):
try:
data = chunk.data
content = "\n".join(str(d) for d in data if isinstance(d, str))
except Exception:
pass
if content:
documents.append(content[:8000]) # VoyageAI has length limits
chunk_map[len(documents) - 1] = chunk
if not documents:
return chunks
try:
vo = voyageai.Client()
result = await asyncio.to_thread(
vo.rerank,
query=query,
documents=documents,
model=model,
top_k=top_k or len(documents),
)
# Map results back to chunks with updated scores
reranked = []
for item in result.results:
chunk = chunk_map.get(item.index)
if chunk:
chunk.relevance_score = item.relevance_score
reranked.append(chunk)
return reranked
except Exception as e:
logger.warning(f"Reranking failed, returning original order: {e}")
return chunks
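
For context, a minimal usage sketch of the new function. This is a hypothetical caller, not part of the commit: it assumes candidate chunks come from an earlier retrieval step and that a VoyageAI API key is available to voyageai.Client() (e.g. via the VOYAGE_API_KEY environment variable).

    import asyncio

    from memory.api.search.rerank import rerank_chunks

    async def demo(query: str, candidates):
        # Rerank the retrieved candidates and keep the five most relevant
        top = await rerank_chunks(query, candidates, top_k=5)
        for chunk in top:
            print(chunk.relevance_score, (chunk.content or "")[:80])

    # asyncio.run(demo("reciprocal rank fusion", candidates))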


@@ -0,0 +1,313 @@
"""
Tests for the rerank module (VoyageAI cross-encoder reranking).
"""
import pytest
from unittest.mock import MagicMock, patch
from memory.api.search.rerank import rerank_chunks, DEFAULT_RERANK_MODEL
class MockRerankResult:
"""Mock for VoyageAI rerank result item."""
def __init__(self, index: int, relevance_score: float):
self.index = index
self.relevance_score = relevance_score
class MockRerankResponse:
"""Mock for VoyageAI rerank response."""
def __init__(self, results: list[MockRerankResult]):
self.results = results
def _make_chunk(content: str = "test content", score: float = 0.5):
"""Create a mock chunk."""
chunk = MagicMock()
chunk.content = content
chunk.relevance_score = score
chunk.data = None
return chunk
# ============================================================================
# Basic reranking tests
# ============================================================================
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_basic(mock_voyageai):
"""Should rerank chunks using VoyageAI."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=1, relevance_score=0.9),
MockRerankResult(index=0, relevance_score=0.7),
])
chunks = [_make_chunk("first", 0.5), _make_chunk("second", 0.6)]
result = await rerank_chunks("test query", chunks)
assert len(result) == 2
assert result[0].relevance_score == 0.9
assert result[1].relevance_score == 0.7
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_reorders_correctly(mock_voyageai):
"""Should correctly reorder multiple chunks."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=2, relevance_score=0.95),
MockRerankResult(index=0, relevance_score=0.85),
MockRerankResult(index=1, relevance_score=0.75),
])
chunk_a = _make_chunk("chunk a", 0.5)
chunk_b = _make_chunk("chunk b", 0.6)
chunk_c = _make_chunk("chunk c", 0.4)
result = await rerank_chunks("query", [chunk_a, chunk_b, chunk_c])
assert result[0] is chunk_c
assert result[1] is chunk_a
assert result[2] is chunk_b
# ============================================================================
# Empty/edge case tests
# ============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize("query", ["", " ", "\t\n"])
async def test_rerank_chunks_empty_query(query):
"""Should return original chunks for empty/whitespace query."""
chunks = [_make_chunk("test", 0.5)]
result = await rerank_chunks(query, chunks)
assert result == chunks
@pytest.mark.asyncio
async def test_rerank_chunks_empty_chunks():
"""Should return empty list for empty input."""
result = await rerank_chunks("test query", [])
assert result == []
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_all_empty_content(mock_voyageai):
"""Should return original chunks if all have empty content."""
chunk1 = MagicMock()
chunk1.content = ""
chunk1.data = None
chunk1.relevance_score = 0.5
chunk2 = MagicMock()
chunk2.content = None
chunk2.data = []
chunk2.relevance_score = 0.6
chunks = [chunk1, chunk2]
result = await rerank_chunks("query", chunks)
assert result == chunks
mock_voyageai.Client.assert_not_called()
# ============================================================================
# Model and parameter tests
# ============================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["rerank-2", "rerank-2-lite", "custom-model"])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_specified_model(mock_voyageai, model):
"""Should use specified model."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
await rerank_chunks("query", [_make_chunk()], model=model)
call_kwargs = mock_client.rerank.call_args[1]
assert call_kwargs["model"] == model
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_default_model(mock_voyageai):
"""Should use default model when not specified."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
await rerank_chunks("query", [_make_chunk()])
call_kwargs = mock_client.rerank.call_args[1]
assert call_kwargs["model"] == DEFAULT_RERANK_MODEL
@pytest.mark.asyncio
@pytest.mark.parametrize("top_k", [1, 5, 10])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_respects_top_k(mock_voyageai, top_k):
"""Should pass top_k to VoyageAI."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
chunks = [_make_chunk() for _ in range(10)]
await rerank_chunks("query", chunks, top_k=top_k)
call_kwargs = mock_client.rerank.call_args[1]
assert call_kwargs["top_k"] == top_k
# ============================================================================
# Content handling tests
# ============================================================================
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_skips_none_content(mock_voyageai):
"""Should skip chunks with None content."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
chunk_with = _make_chunk("has content", 0.5)
chunk_without = _make_chunk(None, 0.5)
chunk_without.content = None
await rerank_chunks("query", [chunk_with, chunk_without])
call_kwargs = mock_client.rerank.call_args[1]
assert len(call_kwargs["documents"]) == 1
assert call_kwargs["documents"][0] == "has content"
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_truncates_long_content(mock_voyageai):
"""Should truncate content to 8000 characters."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
long_content = "x" * 10000
await rerank_chunks("query", [_make_chunk(long_content)])
call_kwargs = mock_client.rerank.call_args[1]
assert len(call_kwargs["documents"][0]) == 8000
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_data_fallback(mock_voyageai):
"""Should fall back to data attribute if content is empty."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
chunk = MagicMock()
chunk.content = ""
chunk.data = ["text from data", 123, "more text"]
chunk.relevance_score = 0.5
await rerank_chunks("query", [chunk])
call_kwargs = mock_client.rerank.call_args[1]
assert "text from data" in call_kwargs["documents"][0]
assert "more text" in call_kwargs["documents"][0]
# ============================================================================
# Error handling tests
# ============================================================================
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_api_error(mock_voyageai):
"""Should return original chunks on API error."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.side_effect = Exception("API error")
chunks = [_make_chunk("test", 0.5)]
result = await rerank_chunks("query", chunks)
assert result == chunks
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_missing_index(mock_voyageai):
"""Should handle missing indices gracefully."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
MockRerankResult(index=99, relevance_score=0.7), # Invalid index
])
result = await rerank_chunks("query", [_make_chunk()])
assert len(result) == 1
# ============================================================================
# Object preservation tests
# ============================================================================
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_preserves_objects(mock_voyageai):
"""Should return the same chunk objects, not copies."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.8),
])
chunk = _make_chunk("test", 0.5)
result = await rerank_chunks("query", [chunk])
assert result[0] is chunk
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_updates_scores(mock_voyageai):
"""Should update chunk relevance_score from reranker."""
mock_client = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock_client.rerank.return_value = MockRerankResponse([
MockRerankResult(index=0, relevance_score=0.95),
])
chunk = _make_chunk("test", 0.5)
result = await rerank_chunks("query", [chunk])
assert result[0].relevance_score == 0.95


@@ -0,0 +1,353 @@
"""
Tests for search module functions including RRF fusion, query term boosting,
title boosting, and source deduplication.
"""
import pytest
from unittest.mock import MagicMock, patch
from memory.api.search.search import (
extract_query_terms,
apply_query_term_boost,
deduplicate_by_source,
apply_title_boost,
fuse_scores_rrf,
STOPWORDS,
QUERY_TERM_BOOST,
TITLE_MATCH_BOOST,
RRF_K,
)
# ============================================================================
# extract_query_terms tests
# ============================================================================
@pytest.mark.parametrize(
"query,expected",
[
("machine learning algorithms", {"machine", "learning", "algorithms"}),
("MACHINE Learning ALGORITHMS", {"machine", "learning", "algorithms"}),
("", set()),
("the is a an of to", set()), # Only stopwords
],
)
def test_extract_query_terms_basic(query, expected):
"""Should extract meaningful terms, lowercase them, and filter stopwords."""
assert extract_query_terms(query) == expected
@pytest.mark.parametrize(
"query,must_include,must_exclude",
[
(
"the quick brown fox jumps with the lazy dog",
{"quick", "brown", "jumps", "lazy", "fox", "dog"},
{"the", "with"},
),
(
"what is the best approach for neural networks",
{"best", "approach", "neural", "networks"},
{"what", "the", "for"},
),
],
)
def test_extract_query_terms_filtering(query, must_include, must_exclude):
"""Should filter stopwords while keeping meaningful terms."""
terms = extract_query_terms(query)
for term in must_include:
assert term in terms, f"'{term}' should be in terms"
for term in must_exclude:
assert term not in terms, f"'{term}' should not be in terms"
@pytest.mark.parametrize(
"query,excluded",
[
("AI is a new ML model", {"ai", "is", "a", "ml"}), # Short words filtered
],
)
def test_extract_query_terms_short_words(query, excluded):
"""Should filter words with 2 or fewer characters."""
terms = extract_query_terms(query)
for term in excluded:
assert term not in terms
@pytest.mark.parametrize(
"word",
["the", "is", "are", "was", "were", "be", "been", "have", "has", "had",
"do", "does", "did", "to", "of", "in", "for", "on", "with", "at", "by"],
)
def test_common_stopwords_in_set(word):
"""Verify common stopwords are in the STOPWORDS set."""
assert word in STOPWORDS
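

# ---------------------------------------------------------------------------
# Hedged sketch, not part of this commit: one plausible extract_query_terms
# consistent with the tests above (lowercase, tokenize, drop stopwords and
# words of two or fewer characters). The regex tokenization is an assumption;
# the real implementation lives in memory.api.search.search.
import re


def _extract_query_terms_sketch(query: str) -> set[str]:
    words = re.findall(r"[a-z0-9]+", query.lower())
    return {w for w in words if len(w) > 2 and w not in STOPWORDS}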


# ============================================================================
# apply_query_term_boost tests
# ============================================================================


def _make_chunk(content: str, source_id: int = 1, score: float = 0.5):
    """Create a mock chunk with given content and score."""
    chunk = MagicMock()
    chunk.content = content
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "content,query_terms,initial_score,expected_boost_fraction",
    [
        ("machine learning is powerful", {"machine", "learning"}, 0.5, 1.0),  # Both match
        ("machine vision systems", {"machine", "learning"}, 0.5, 0.5),  # One of two
        ("deep neural networks", {"machine", "learning"}, 0.5, 0.0),  # No match
        ("MACHINE Learning AlGoRiThMs", {"machine", "learning", "algorithms"}, 0.5, 1.0),  # Case insensitive
    ],
)
def test_apply_query_term_boost(content, query_terms, initial_score, expected_boost_fraction):
    """Should boost chunks based on query term matches."""
    chunks = [_make_chunk(content, score=initial_score)]
    apply_query_term_boost(chunks, query_terms)
    expected = initial_score + QUERY_TERM_BOOST * expected_boost_fraction
    assert chunks[0].relevance_score == pytest.approx(expected)


def test_apply_query_term_boost_empty_inputs():
    """Should handle empty query_terms or chunks."""
    chunks = [_make_chunk("machine learning", score=0.5)]
    apply_query_term_boost(chunks, set())
    assert chunks[0].relevance_score == 0.5
    apply_query_term_boost([], {"machine"})  # Should not raise


def test_apply_query_term_boost_none_values():
    """Should handle None content and relevance_score."""
    chunk_none_content = MagicMock()
    chunk_none_content.content = None
    chunk_none_content.relevance_score = 0.5
    apply_query_term_boost([chunk_none_content], {"machine"})
    assert chunk_none_content.relevance_score == 0.5

    chunk_none_score = MagicMock()
    chunk_none_score.content = "machine learning"
    chunk_none_score.relevance_score = None
    apply_query_term_boost([chunk_none_score], {"machine", "learning"})
    assert chunk_none_score.relevance_score == pytest.approx(QUERY_TERM_BOOST)


def test_apply_query_term_boost_multiple_chunks():
    """Should boost each chunk independently."""
    chunks = [
        _make_chunk("machine learning", score=0.5),
        _make_chunk("deep networks", score=0.6),
        _make_chunk("machine vision", score=0.4),
    ]
    query_terms = {"machine", "learning"}
    apply_query_term_boost(chunks, query_terms)
    assert chunks[0].relevance_score == pytest.approx(0.5 + QUERY_TERM_BOOST)
    assert chunks[1].relevance_score == 0.6  # No match
    assert chunks[2].relevance_score == pytest.approx(0.4 + QUERY_TERM_BOOST * 0.5)
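

# ---------------------------------------------------------------------------
# Hedged sketch, not part of this commit: the boost rule the tests above
# encode -- each chunk gains QUERY_TERM_BOOST scaled by the fraction of query
# terms found in its (lowercased) content, tolerating None content and scores.
def _apply_query_term_boost_sketch(chunks, query_terms):
    if not query_terms:
        return
    for chunk in chunks:
        content = (chunk.content or "").lower()
        if not content:
            continue
        matched = sum(1 for term in query_terms if term in content)
        chunk.relevance_score = (chunk.relevance_score or 0.0) + (
            QUERY_TERM_BOOST * matched / len(query_terms)
        )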


# ============================================================================
# deduplicate_by_source tests
# ============================================================================


def _make_source_chunk(source_id: int, score: float):
    """Create a mock chunk with given source_id and score."""
    chunk = MagicMock()
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "chunks_data,expected_count,expected_scores",
    [
        # Multiple chunks per source - keep highest
        ([(1, 0.5), (1, 0.8), (1, 0.3), (2, 0.6)], 2, {1: 0.8, 2: 0.6}),
        # Single chunk per source - keep all
        ([(1, 0.5), (2, 0.6), (3, 0.7)], 3, {1: 0.5, 2: 0.6, 3: 0.7}),
        # Empty list
        ([], 0, {}),
    ],
)
def test_deduplicate_by_source(chunks_data, expected_count, expected_scores):
    """Should keep only highest scoring chunk per source."""
    chunks = [_make_source_chunk(sid, score) for sid, score in chunks_data]
    result = deduplicate_by_source(chunks)
    assert len(result) == expected_count
    for chunk in result:
        assert chunk.relevance_score == expected_scores[chunk.source_id]


def test_deduplicate_by_source_preserves_objects():
    """Should return the actual chunk objects, not copies."""
    chunk1 = _make_source_chunk(1, 0.5)
    chunk2 = _make_source_chunk(1, 0.8)
    result = deduplicate_by_source([chunk1, chunk2])
    assert result[0] is chunk2


def test_deduplicate_by_source_none_scores():
    """Should handle None relevance_score as 0."""
    chunk1 = _make_source_chunk(1, None)
    chunk2 = _make_source_chunk(1, 0.5)
    result = deduplicate_by_source([chunk1, chunk2])
    assert result[0].relevance_score == 0.5
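

# ---------------------------------------------------------------------------
# Hedged sketch, not part of this commit: the behavior the tests above pin
# down -- keep the single highest-scoring chunk per source_id, treating a None
# relevance_score as 0 and returning the original objects rather than copies.
def _deduplicate_by_source_sketch(chunks):
    best = {}
    for chunk in chunks:
        kept = best.get(chunk.source_id)
        if kept is None or (chunk.relevance_score or 0.0) > (kept.relevance_score or 0.0):
            best[chunk.source_id] = chunk
    return list(best.values())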


# ============================================================================
# apply_title_boost tests
# ============================================================================


def _make_title_chunk(source_id: int, score: float = 0.5):
    """Create a mock chunk for title boost tests."""
    chunk = MagicMock()
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "title,query_terms,initial_score,expected_boost_fraction",
    [
        ("Machine Learning Tutorial", {"machine", "learning"}, 0.5, 1.0),
        ("Machine Vision Systems", {"machine", "learning"}, 0.5, 0.5),
        ("Deep Neural Networks", {"machine", "learning"}, 0.5, 0.0),
        ("MACHINE LEARNING Tutorial", {"machine", "learning"}, 0.5, 1.0),  # Case insensitive
    ],
)
@patch("memory.api.search.search.make_session")
def test_apply_title_boost(mock_make_session, title, query_terms, initial_score, expected_boost_fraction):
    """Should boost chunks when title matches query terms."""
    mock_session = MagicMock()
    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
    mock_source = MagicMock()
    mock_source.id = 1
    mock_source.title = title
    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
    chunks = [_make_title_chunk(1, initial_score)]
    apply_title_boost(chunks, query_terms)
    expected = initial_score + TITLE_MATCH_BOOST * expected_boost_fraction
    assert chunks[0].relevance_score == pytest.approx(expected)


def test_apply_title_boost_empty_inputs():
    """Should not modify chunks if query_terms or chunks is empty."""
    chunks = [_make_title_chunk(1, 0.5)]
    apply_title_boost(chunks, set())
    assert chunks[0].relevance_score == 0.5
    apply_title_boost([], {"machine"})  # Should not raise


@patch("memory.api.search.search.make_session")
def test_apply_title_boost_none_title(mock_make_session):
    """Should handle sources with None or missing title."""
    mock_session = MagicMock()
    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
    # Source with None title
    mock_source = MagicMock()
    mock_source.id = 1
    mock_source.title = None
    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
    chunks = [_make_title_chunk(1, 0.5)]
    apply_title_boost(chunks, {"machine"})
    assert chunks[0].relevance_score == 0.5
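

# ---------------------------------------------------------------------------
# Hedged sketch, not part of this commit: the boost arithmetic the tests above
# check. The real apply_title_boost resolves titles itself via make_session;
# here a source_id -> title mapping is passed in to keep the sketch
# self-contained, so the DB query shape is deliberately omitted.
def _apply_title_boost_sketch(chunks, query_terms, titles):
    if not chunks or not query_terms:
        return
    for chunk in chunks:
        title = (titles.get(chunk.source_id) or "").lower()
        if not title:
            continue
        matched = sum(1 for term in query_terms if term in title)
        chunk.relevance_score = (chunk.relevance_score or 0.0) + (
            TITLE_MATCH_BOOST * matched / len(query_terms)
        )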


# ============================================================================
# fuse_scores_rrf tests
# ============================================================================


@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_key,expected_score",
    [
        # Both sources have same ranking
        ({"a": 0.9, "b": 0.7}, {"a": 0.8, "b": 0.6}, "a", 2 / (RRF_K + 1)),
        # Item only in embeddings
        ({"a": 0.9, "b": 0.7}, {"a": 0.8}, "b", 1 / (RRF_K + 2)),
        # Item only in BM25
        ({"a": 0.9}, {"a": 0.8, "b": 0.7}, "b", 1 / (RRF_K + 2)),
        # Single item in both
        ({"a": 0.9}, {"a": 0.8}, "a", 2 / (RRF_K + 1)),
    ],
)
def test_fuse_scores_rrf_basic(embedding_scores, bm25_scores, expected_key, expected_score):
    """Should compute RRF scores correctly."""
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert result[expected_key] == pytest.approx(expected_score)


def test_fuse_scores_rrf_different_rankings():
    """Should handle items ranked differently in each source."""
    embedding_scores = {"a": 0.9, "b": 0.5}  # a=1, b=2
    bm25_scores = {"a": 0.3, "b": 0.8}  # b=1, a=2
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    # Both should have same RRF score (1/61 + 1/62)
    expected = 1 / (RRF_K + 1) + 1 / (RRF_K + 2)
    assert result["a"] == pytest.approx(expected)
    assert result["b"] == pytest.approx(expected)


@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_len",
    [
        ({}, {}, 0),
        ({}, {"a": 0.8, "b": 0.6}, 2),
        ({"a": 0.9, "b": 0.7}, {}, 2),
    ],
)
def test_fuse_scores_rrf_empty_inputs(embedding_scores, bm25_scores, expected_len):
    """Should handle empty inputs gracefully."""
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert len(result) == expected_len


def test_fuse_scores_rrf_many_items():
    """Should handle many items correctly."""
    embedding_scores = {str(i): 1.0 - i * 0.01 for i in range(100)}
    bm25_scores = {str(i): 1.0 - i * 0.01 for i in range(100)}
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert len(result) == 100
    assert result["0"] > result["99"]  # First should have highest score


def test_fuse_scores_rrf_only_ranks_matter():
    """RRF should only care about ranks, not score magnitudes."""
    # Same ranking, different score scales
    result1 = fuse_scores_rrf(
        {"a": 0.99, "b": 0.98, "c": 0.97},
        {"a": 100, "b": 50, "c": 1},
    )
    result2 = fuse_scores_rrf(
        {"a": 0.5, "b": 0.4, "c": 0.3},
        {"a": 0.9, "b": 0.8, "c": 0.7},
    )
    # RRF scores should be identical since rankings are the same
    assert result1["a"] == pytest.approx(result2["a"])
    assert result1["b"] == pytest.approx(result2["b"])
    assert result1["c"] == pytest.approx(result2["c"])