From 09215adf9a4501043481adaea4f3685afe064d08 Mon Sep 17 00:00:00 2001
From: mruwnik
Date: Sat, 20 Dec 2025 22:26:16 +0000
Subject: [PATCH] Add comprehensive tests for search improvements
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add tests for extract_query_terms (stopword filtering, short words)
- Add tests for apply_query_term_boost (boost calculations, edge cases)
- Add tests for deduplicate_by_source (keeps highest per source)
- Add tests for apply_title_boost (title matching with mocked DB)
- Add tests for fuse_scores_rrf (RRF score fusion, ranking behavior)
- Add tests for rerank module (VoyageAI reranker mocking)

Uses pytest.mark.parametrize for concise, data-driven tests.
77 tests total covering all new search functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 src/memory/api/search/rerank.py        |  94 +++++++
 tests/memory/api/search/test_rerank.py | 313 ++++++++++++++++++++++
 tests/memory/api/search/test_search.py | 353 +++++++++++++++++++++++++
 3 files changed, 760 insertions(+)
 create mode 100644 src/memory/api/search/rerank.py
 create mode 100644 tests/memory/api/search/test_rerank.py
 create mode 100644 tests/memory/api/search/test_search.py

diff --git a/src/memory/api/search/rerank.py b/src/memory/api/search/rerank.py
new file mode 100644
index 0000000..d5d68c8
--- /dev/null
+++ b/src/memory/api/search/rerank.py
@@ -0,0 +1,94 @@
+"""
+Cross-encoder reranking using VoyageAI's reranker.
+
+Reranking improves search precision by using a cross-encoder model that
+sees query and document together, rather than comparing embeddings separately.
+"""
+
+import asyncio
+import logging
+from typing import Optional
+
+import voyageai
+
+from memory.common.db.models import Chunk
+
+logger = logging.getLogger(__name__)
+
+# VoyageAI reranker models
+# rerank-2: More accurate, slower
+# rerank-2-lite: Faster, slightly less accurate
+DEFAULT_RERANK_MODEL = "rerank-2-lite"
+
+
+async def rerank_chunks(
+    query: str,
+    chunks: list[Chunk],
+    model: str = DEFAULT_RERANK_MODEL,
+    top_k: Optional[int] = None,
+) -> list[Chunk]:
+    """
+    Rerank chunks using VoyageAI's cross-encoder reranker.
+
+    Cross-encoders are more accurate than bi-encoders (embeddings) because
+    they see query and document together, allowing for deeper semantic matching.
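+
+    Example (illustrative only; ``candidate_chunks`` is assumed to come from
+    an earlier retrieval step):
+
+        top = await rerank_chunks("what is RRF?", candidate_chunks, top_k=5)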
+ + Args: + query: The search query + chunks: List of candidate chunks to rerank + model: VoyageAI reranker model to use + top_k: If set, only return top k results + + Returns: + Chunks sorted by reranker relevance score + """ + if not chunks: + return [] + + if not query.strip(): + return chunks + + # Extract text content from chunks + documents = [] + chunk_map = {} # Map index to chunk + + for i, chunk in enumerate(chunks): + content = chunk.content or "" + if not content and hasattr(chunk, "data"): + try: + data = chunk.data + content = "\n".join(str(d) for d in data if isinstance(d, str)) + except Exception: + pass + + if content: + documents.append(content[:8000]) # VoyageAI has length limits + chunk_map[len(documents) - 1] = chunk + + if not documents: + return chunks + + try: + vo = voyageai.Client() + result = await asyncio.to_thread( + vo.rerank, + query=query, + documents=documents, + model=model, + top_k=top_k or len(documents), + ) + + # Map results back to chunks with updated scores + reranked = [] + for item in result.results: + chunk = chunk_map.get(item.index) + if chunk: + chunk.relevance_score = item.relevance_score + reranked.append(chunk) + + return reranked + + except Exception as e: + logger.warning(f"Reranking failed, returning original order: {e}") + return chunks diff --git a/tests/memory/api/search/test_rerank.py b/tests/memory/api/search/test_rerank.py new file mode 100644 index 0000000..f35f8c0 --- /dev/null +++ b/tests/memory/api/search/test_rerank.py @@ -0,0 +1,313 @@ +""" +Tests for the rerank module (VoyageAI cross-encoder reranking). +""" + +import pytest +from unittest.mock import MagicMock, patch + +from memory.api.search.rerank import rerank_chunks, DEFAULT_RERANK_MODEL + + +class MockRerankResult: + """Mock for VoyageAI rerank result item.""" + + def __init__(self, index: int, relevance_score: float): + self.index = index + self.relevance_score = relevance_score + + +class MockRerankResponse: + """Mock for VoyageAI rerank response.""" + + def __init__(self, results: list[MockRerankResult]): + self.results = results + + +def _make_chunk(content: str = "test content", score: float = 0.5): + """Create a mock chunk.""" + chunk = MagicMock() + chunk.content = content + chunk.relevance_score = score + chunk.data = None + return chunk + + +# ============================================================================ +# Basic reranking tests +# ============================================================================ + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_basic(mock_voyageai): + """Should rerank chunks using VoyageAI.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=1, relevance_score=0.9), + MockRerankResult(index=0, relevance_score=0.7), + ]) + + chunks = [_make_chunk("first", 0.5), _make_chunk("second", 0.6)] + result = await rerank_chunks("test query", chunks) + + assert len(result) == 2 + assert result[0].relevance_score == 0.9 + assert result[1].relevance_score == 0.7 + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_reorders_correctly(mock_voyageai): + """Should correctly reorder multiple chunks.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=2, relevance_score=0.95), + MockRerankResult(index=0, 
relevance_score=0.85), + MockRerankResult(index=1, relevance_score=0.75), + ]) + + chunk_a = _make_chunk("chunk a", 0.5) + chunk_b = _make_chunk("chunk b", 0.6) + chunk_c = _make_chunk("chunk c", 0.4) + + result = await rerank_chunks("query", [chunk_a, chunk_b, chunk_c]) + + assert result[0] is chunk_c + assert result[1] is chunk_a + assert result[2] is chunk_b + + +# ============================================================================ +# Empty/edge case tests +# ============================================================================ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("query", ["", " ", "\t\n"]) +async def test_rerank_chunks_empty_query(query): + """Should return original chunks for empty/whitespace query.""" + chunks = [_make_chunk("test", 0.5)] + result = await rerank_chunks(query, chunks) + assert result == chunks + + +@pytest.mark.asyncio +async def test_rerank_chunks_empty_chunks(): + """Should return empty list for empty input.""" + result = await rerank_chunks("test query", []) + assert result == [] + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_all_empty_content(mock_voyageai): + """Should return original chunks if all have empty content.""" + chunk1 = MagicMock() + chunk1.content = "" + chunk1.data = None + chunk1.relevance_score = 0.5 + + chunk2 = MagicMock() + chunk2.content = None + chunk2.data = [] + chunk2.relevance_score = 0.6 + + chunks = [chunk1, chunk2] + result = await rerank_chunks("query", chunks) + + assert result == chunks + mock_voyageai.Client.assert_not_called() + + +# ============================================================================ +# Model and parameter tests +# ============================================================================ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", ["rerank-2", "rerank-2-lite", "custom-model"]) +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_uses_specified_model(mock_voyageai, model): + """Should use specified model.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + await rerank_chunks("query", [_make_chunk()], model=model) + + call_kwargs = mock_client.rerank.call_args[1] + assert call_kwargs["model"] == model + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_uses_default_model(mock_voyageai): + """Should use default model when not specified.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + await rerank_chunks("query", [_make_chunk()]) + + call_kwargs = mock_client.rerank.call_args[1] + assert call_kwargs["model"] == DEFAULT_RERANK_MODEL + + +@pytest.mark.asyncio +@pytest.mark.parametrize("top_k", [1, 5, 10]) +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_respects_top_k(mock_voyageai, top_k): + """Should pass top_k to VoyageAI.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + chunks = [_make_chunk() for _ in range(10)] + await rerank_chunks("query", chunks, top_k=top_k) + + call_kwargs = mock_client.rerank.call_args[1] + assert call_kwargs["top_k"] == top_k + + +# 
============================================================================ +# Content handling tests +# ============================================================================ + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_skips_none_content(mock_voyageai): + """Should skip chunks with None content.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + chunk_with = _make_chunk("has content", 0.5) + chunk_without = _make_chunk(None, 0.5) + chunk_without.content = None + + await rerank_chunks("query", [chunk_with, chunk_without]) + + call_kwargs = mock_client.rerank.call_args[1] + assert len(call_kwargs["documents"]) == 1 + assert call_kwargs["documents"][0] == "has content" + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_truncates_long_content(mock_voyageai): + """Should truncate content to 8000 characters.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + long_content = "x" * 10000 + await rerank_chunks("query", [_make_chunk(long_content)]) + + call_kwargs = mock_client.rerank.call_args[1] + assert len(call_kwargs["documents"][0]) == 8000 + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_uses_data_fallback(mock_voyageai): + """Should fall back to data attribute if content is empty.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + chunk = MagicMock() + chunk.content = "" + chunk.data = ["text from data", 123, "more text"] + chunk.relevance_score = 0.5 + + await rerank_chunks("query", [chunk]) + + call_kwargs = mock_client.rerank.call_args[1] + assert "text from data" in call_kwargs["documents"][0] + assert "more text" in call_kwargs["documents"][0] + + +# ============================================================================ +# Error handling tests +# ============================================================================ + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_handles_api_error(mock_voyageai): + """Should return original chunks on API error.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.side_effect = Exception("API error") + + chunks = [_make_chunk("test", 0.5)] + result = await rerank_chunks("query", chunks) + + assert result == chunks + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_handles_missing_index(mock_voyageai): + """Should handle missing indices gracefully.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + MockRerankResult(index=99, relevance_score=0.7), # Invalid index + ]) + + result = await rerank_chunks("query", [_make_chunk()]) + assert len(result) == 1 + + +# ============================================================================ +# Object preservation tests +# ============================================================================ + + +@pytest.mark.asyncio 
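+# Note: @patch is the innermost decorator, so unittest.mock injects
+# mock_voyageai as the first argument the test function receives.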
+@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_preserves_objects(mock_voyageai): + """Should return the same chunk objects, not copies.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.8), + ]) + + chunk = _make_chunk("test", 0.5) + result = await rerank_chunks("query", [chunk]) + + assert result[0] is chunk + + +@pytest.mark.asyncio +@patch("memory.api.search.rerank.voyageai") +async def test_rerank_chunks_updates_scores(mock_voyageai): + """Should update chunk relevance_score from reranker.""" + mock_client = MagicMock() + mock_voyageai.Client.return_value = mock_client + mock_client.rerank.return_value = MockRerankResponse([ + MockRerankResult(index=0, relevance_score=0.95), + ]) + + chunk = _make_chunk("test", 0.5) + result = await rerank_chunks("query", [chunk]) + + assert result[0].relevance_score == 0.95 diff --git a/tests/memory/api/search/test_search.py b/tests/memory/api/search/test_search.py new file mode 100644 index 0000000..15b2359 --- /dev/null +++ b/tests/memory/api/search/test_search.py @@ -0,0 +1,353 @@ +""" +Tests for search module functions including RRF fusion, query term boosting, +title boosting, and source deduplication. +""" + +import pytest +from unittest.mock import MagicMock, patch + +from memory.api.search.search import ( + extract_query_terms, + apply_query_term_boost, + deduplicate_by_source, + apply_title_boost, + fuse_scores_rrf, + STOPWORDS, + QUERY_TERM_BOOST, + TITLE_MATCH_BOOST, + RRF_K, +) + + +# ============================================================================ +# extract_query_terms tests +# ============================================================================ + + +@pytest.mark.parametrize( + "query,expected", + [ + ("machine learning algorithms", {"machine", "learning", "algorithms"}), + ("MACHINE Learning ALGORITHMS", {"machine", "learning", "algorithms"}), + ("", set()), + ("the is a an of to", set()), # Only stopwords + ], +) +def test_extract_query_terms_basic(query, expected): + """Should extract meaningful terms, lowercase them, and filter stopwords.""" + assert extract_query_terms(query) == expected + + +@pytest.mark.parametrize( + "query,must_include,must_exclude", + [ + ( + "the quick brown fox jumps with the lazy dog", + {"quick", "brown", "jumps", "lazy", "fox", "dog"}, + {"the", "with"}, + ), + ( + "what is the best approach for neural networks", + {"best", "approach", "neural", "networks"}, + {"what", "the", "for"}, + ), + ], +) +def test_extract_query_terms_filtering(query, must_include, must_exclude): + """Should filter stopwords while keeping meaningful terms.""" + terms = extract_query_terms(query) + for term in must_include: + assert term in terms, f"'{term}' should be in terms" + for term in must_exclude: + assert term not in terms, f"'{term}' should not be in terms" + + +@pytest.mark.parametrize( + "query,excluded", + [ + ("AI is a new ML model", {"ai", "is", "a", "ml"}), # Short words filtered + ], +) +def test_extract_query_terms_short_words(query, excluded): + """Should filter words with 2 or fewer characters.""" + terms = extract_query_terms(query) + for term in excluded: + assert term not in terms + + +@pytest.mark.parametrize( + "word", + ["the", "is", "are", "was", "were", "be", "been", "have", "has", "had", + "do", "does", "did", "to", "of", "in", "for", "on", "with", "at", "by"], +) +def test_common_stopwords_in_set(word): + """Verify common 
stopwords are in the STOPWORDS set.""" + assert word in STOPWORDS + + +# ============================================================================ +# apply_query_term_boost tests +# ============================================================================ + + +def _make_chunk(content: str, source_id: int = 1, score: float = 0.5): + """Create a mock chunk with given content and score.""" + chunk = MagicMock() + chunk.content = content + chunk.source_id = source_id + chunk.relevance_score = score + return chunk + + +@pytest.mark.parametrize( + "content,query_terms,initial_score,expected_boost_fraction", + [ + ("machine learning is powerful", {"machine", "learning"}, 0.5, 1.0), # Both match + ("machine vision systems", {"machine", "learning"}, 0.5, 0.5), # One of two + ("deep neural networks", {"machine", "learning"}, 0.5, 0.0), # No match + ("MACHINE Learning AlGoRiThMs", {"machine", "learning", "algorithms"}, 0.5, 1.0), # Case insensitive + ], +) +def test_apply_query_term_boost(content, query_terms, initial_score, expected_boost_fraction): + """Should boost chunks based on query term matches.""" + chunks = [_make_chunk(content, score=initial_score)] + apply_query_term_boost(chunks, query_terms) + expected = initial_score + QUERY_TERM_BOOST * expected_boost_fraction + assert chunks[0].relevance_score == pytest.approx(expected) + + +def test_apply_query_term_boost_empty_inputs(): + """Should handle empty query_terms or chunks.""" + chunks = [_make_chunk("machine learning", score=0.5)] + apply_query_term_boost(chunks, set()) + assert chunks[0].relevance_score == 0.5 + + apply_query_term_boost([], {"machine"}) # Should not raise + + +def test_apply_query_term_boost_none_values(): + """Should handle None content and relevance_score.""" + chunk_none_content = MagicMock() + chunk_none_content.content = None + chunk_none_content.relevance_score = 0.5 + apply_query_term_boost([chunk_none_content], {"machine"}) + assert chunk_none_content.relevance_score == 0.5 + + chunk_none_score = MagicMock() + chunk_none_score.content = "machine learning" + chunk_none_score.relevance_score = None + apply_query_term_boost([chunk_none_score], {"machine", "learning"}) + assert chunk_none_score.relevance_score == pytest.approx(QUERY_TERM_BOOST) + + +def test_apply_query_term_boost_multiple_chunks(): + """Should boost each chunk independently.""" + chunks = [ + _make_chunk("machine learning", score=0.5), + _make_chunk("deep networks", score=0.6), + _make_chunk("machine vision", score=0.4), + ] + query_terms = {"machine", "learning"} + apply_query_term_boost(chunks, query_terms) + + assert chunks[0].relevance_score == pytest.approx(0.5 + QUERY_TERM_BOOST) + assert chunks[1].relevance_score == 0.6 # No match + assert chunks[2].relevance_score == pytest.approx(0.4 + QUERY_TERM_BOOST * 0.5) + + +# ============================================================================ +# deduplicate_by_source tests +# ============================================================================ + + +def _make_source_chunk(source_id: int, score: float): + """Create a mock chunk with given source_id and score.""" + chunk = MagicMock() + chunk.source_id = source_id + chunk.relevance_score = score + return chunk + + +@pytest.mark.parametrize( + "chunks_data,expected_count,expected_scores", + [ + # Multiple chunks per source - keep highest + ([(1, 0.5), (1, 0.8), (1, 0.3), (2, 0.6)], 2, {1: 0.8, 2: 0.6}), + # Single chunk per source - keep all + ([(1, 0.5), (2, 0.6), (3, 0.7)], 3, {1: 0.5, 2: 0.6, 3: 0.7}), + # Empty list + ([], 0, 
{}), + ], +) +def test_deduplicate_by_source(chunks_data, expected_count, expected_scores): + """Should keep only highest scoring chunk per source.""" + chunks = [_make_source_chunk(sid, score) for sid, score in chunks_data] + result = deduplicate_by_source(chunks) + + assert len(result) == expected_count + for chunk in result: + assert chunk.relevance_score == expected_scores[chunk.source_id] + + +def test_deduplicate_by_source_preserves_objects(): + """Should return the actual chunk objects, not copies.""" + chunk1 = _make_source_chunk(1, 0.5) + chunk2 = _make_source_chunk(1, 0.8) + result = deduplicate_by_source([chunk1, chunk2]) + assert result[0] is chunk2 + + +def test_deduplicate_by_source_none_scores(): + """Should handle None relevance_score as 0.""" + chunk1 = _make_source_chunk(1, None) + chunk2 = _make_source_chunk(1, 0.5) + result = deduplicate_by_source([chunk1, chunk2]) + assert result[0].relevance_score == 0.5 + + +# ============================================================================ +# apply_title_boost tests +# ============================================================================ + + +def _make_title_chunk(source_id: int, score: float = 0.5): + """Create a mock chunk for title boost tests.""" + chunk = MagicMock() + chunk.source_id = source_id + chunk.relevance_score = score + return chunk + + +@pytest.mark.parametrize( + "title,query_terms,initial_score,expected_boost_fraction", + [ + ("Machine Learning Tutorial", {"machine", "learning"}, 0.5, 1.0), + ("Machine Vision Systems", {"machine", "learning"}, 0.5, 0.5), + ("Deep Neural Networks", {"machine", "learning"}, 0.5, 0.0), + ("MACHINE LEARNING Tutorial", {"machine", "learning"}, 0.5, 1.0), # Case insensitive + ], +) +@patch("memory.api.search.search.make_session") +def test_apply_title_boost(mock_make_session, title, query_terms, initial_score, expected_boost_fraction): + """Should boost chunks when title matches query terms.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = title + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_title_chunk(1, initial_score)] + apply_title_boost(chunks, query_terms) + + expected = initial_score + TITLE_MATCH_BOOST * expected_boost_fraction + assert chunks[0].relevance_score == pytest.approx(expected) + + +def test_apply_title_boost_empty_inputs(): + """Should not modify chunks if query_terms or chunks is empty.""" + chunks = [_make_title_chunk(1, 0.5)] + apply_title_boost(chunks, set()) + assert chunks[0].relevance_score == 0.5 + + apply_title_boost([], {"machine"}) # Should not raise + + +@patch("memory.api.search.search.make_session") +def test_apply_title_boost_none_title(mock_make_session): + """Should handle sources with None or missing title.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + # Source with None title + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_title_chunk(1, 0.5)] + apply_title_boost(chunks, {"machine"}) + assert chunks[0].relevance_score == 0.5 + + +# 
============================================================================ +# fuse_scores_rrf tests +# ============================================================================ + + +@pytest.mark.parametrize( + "embedding_scores,bm25_scores,expected_key,expected_score", + [ + # Both sources have same ranking + ({"a": 0.9, "b": 0.7}, {"a": 0.8, "b": 0.6}, "a", 2 / (RRF_K + 1)), + # Item only in embeddings + ({"a": 0.9, "b": 0.7}, {"a": 0.8}, "b", 1 / (RRF_K + 2)), + # Item only in BM25 + ({"a": 0.9}, {"a": 0.8, "b": 0.7}, "b", 1 / (RRF_K + 2)), + # Single item in both + ({"a": 0.9}, {"a": 0.8}, "a", 2 / (RRF_K + 1)), + ], +) +def test_fuse_scores_rrf_basic(embedding_scores, bm25_scores, expected_key, expected_score): + """Should compute RRF scores correctly.""" + result = fuse_scores_rrf(embedding_scores, bm25_scores) + assert result[expected_key] == pytest.approx(expected_score) + + +def test_fuse_scores_rrf_different_rankings(): + """Should handle items ranked differently in each source.""" + embedding_scores = {"a": 0.9, "b": 0.5} # a=1, b=2 + bm25_scores = {"a": 0.3, "b": 0.8} # b=1, a=2 + + result = fuse_scores_rrf(embedding_scores, bm25_scores) + + # Both should have same RRF score (1/61 + 1/62) + expected = 1 / (RRF_K + 1) + 1 / (RRF_K + 2) + assert result["a"] == pytest.approx(expected) + assert result["b"] == pytest.approx(expected) + + +@pytest.mark.parametrize( + "embedding_scores,bm25_scores,expected_len", + [ + ({}, {}, 0), + ({}, {"a": 0.8, "b": 0.6}, 2), + ({"a": 0.9, "b": 0.7}, {}, 2), + ], +) +def test_fuse_scores_rrf_empty_inputs(embedding_scores, bm25_scores, expected_len): + """Should handle empty inputs gracefully.""" + result = fuse_scores_rrf(embedding_scores, bm25_scores) + assert len(result) == expected_len + + +def test_fuse_scores_rrf_many_items(): + """Should handle many items correctly.""" + embedding_scores = {str(i): 1.0 - i * 0.01 for i in range(100)} + bm25_scores = {str(i): 1.0 - i * 0.01 for i in range(100)} + + result = fuse_scores_rrf(embedding_scores, bm25_scores) + + assert len(result) == 100 + assert result["0"] > result["99"] # First should have highest score + + +def test_fuse_scores_rrf_only_ranks_matter(): + """RRF should only care about ranks, not score magnitudes.""" + # Same ranking, different score scales + result1 = fuse_scores_rrf( + {"a": 0.99, "b": 0.98, "c": 0.97}, + {"a": 100, "b": 50, "c": 1}, + ) + result2 = fuse_scores_rrf( + {"a": 0.5, "b": 0.4, "c": 0.3}, + {"a": 0.9, "b": 0.8, "c": 0.7}, + ) + + # RRF scores should be identical since rankings are the same + assert result1["a"] == pytest.approx(result2["a"]) + assert result1["b"] == pytest.approx(result2["b"]) + assert result1["c"] == pytest.approx(result2["c"])
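
Note for reviewers: the fuse_scores_rrf expectations above (2 / (RRF_K + 1) for
an item ranked first in both sources, 1 / (RRF_K + 2) for an item ranked second
or present in only one source, and the rank-only behaviour) assume a standard
reciprocal rank fusion. Below is a minimal sketch of that shape (assuming
RRF_K = 60; the real implementation in src/memory/api/search/search.py is
authoritative):

    RRF_K = 60  # assumed value; the tests import the real constant

    def fuse_scores_rrf(embedding_scores: dict[str, float],
                        bm25_scores: dict[str, float]) -> dict[str, float]:
        """Fuse two score maps by reciprocal rank; score magnitudes are ignored."""
        fused: dict[str, float] = {}
        for scores in (embedding_scores, bm25_scores):
            # Rank 1 is the highest score within each source.
            ranked = sorted(scores, key=scores.get, reverse=True)
            for rank, key in enumerate(ranked, start=1):
                fused[key] = fused.get(key, 0.0) + 1.0 / (RRF_K + rank)
        return fused

Under this sketch an item appearing in only one source at rank r contributes
just 1 / (RRF_K + r), which is exactly what the parametrized cases check.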