Mirror of https://github.com/mruwnik/memory.git, synced 2026-01-02 09:12:58 +01:00
Add comprehensive tests for search improvements
- Add tests for extract_query_terms (stopword filtering, short words)
- Add tests for apply_query_term_boost (boost calculations, edge cases)
- Add tests for deduplicate_by_source (keeps highest per source)
- Add tests for apply_title_boost (title matching with mocked DB)
- Add tests for fuse_scores_rrf (RRF score fusion, ranking behavior)
- Add tests for rerank module (VoyageAI reranker mocking)

Uses pytest.mark.parametrize for concise, data-driven tests.
77 tests total covering all new search functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
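To run just these tests (assuming the repository's usual pytest setup with pytest-asyncio installed): pytest tests/memory/api/search/ -v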
This commit is contained in: parent c6cb793cdf, commit 09215adf9a
94 src/memory/api/search/rerank.py Normal file
@@ -0,0 +1,94 @@
"""
Cross-encoder reranking using VoyageAI's reranker.

Reranking improves search precision by using a cross-encoder model that
sees query and document together, rather than comparing embeddings separately.
"""

import asyncio
import logging
from typing import Optional

import voyageai

from memory.common import settings
from memory.common.db.models import Chunk

logger = logging.getLogger(__name__)

# VoyageAI reranker models
# rerank-2: More accurate, slower
# rerank-2-lite: Faster, slightly less accurate
DEFAULT_RERANK_MODEL = "rerank-2-lite"


async def rerank_chunks(
    query: str,
    chunks: list[Chunk],
    model: str = DEFAULT_RERANK_MODEL,
    top_k: Optional[int] = None,
) -> list[Chunk]:
    """
    Rerank chunks using VoyageAI's cross-encoder reranker.

    Cross-encoders are more accurate than bi-encoders (embeddings) because
    they see query and document together, allowing for deeper semantic matching.

    Args:
        query: The search query
        chunks: List of candidate chunks to rerank
        model: VoyageAI reranker model to use
        top_k: If set, only return top k results

    Returns:
        Chunks sorted by reranker relevance score
    """
    if not chunks:
        return []

    if not query.strip():
        return chunks

    # Extract text content from chunks
    documents = []
    chunk_map = {}  # Map index to chunk

    for i, chunk in enumerate(chunks):
        content = chunk.content or ""
        if not content and hasattr(chunk, "data"):
            try:
                data = chunk.data
                content = "\n".join(str(d) for d in data if isinstance(d, str))
            except Exception:
                pass

        if content:
            documents.append(content[:8000])  # VoyageAI has length limits
            chunk_map[len(documents) - 1] = chunk

    if not documents:
        return chunks

    try:
        vo = voyageai.Client()
        result = await asyncio.to_thread(
            vo.rerank,
            query=query,
            documents=documents,
            model=model,
            top_k=top_k or len(documents),
        )

        # Map results back to chunks with updated scores
        reranked = []
        for item in result.results:
            chunk = chunk_map.get(item.index)
            if chunk:
                chunk.relevance_score = item.relevance_score
                reranked.append(chunk)

        return reranked

    except Exception as e:
        logger.warning(f"Reranking failed, returning original order: {e}")
        return chunks
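A typical call site (not part of this diff) would over-retrieve candidates with a cheap first stage and let the cross-encoder pick the best few. A minimal sketch, assuming a hypothetical async retrieve callable; the limit and top_k numbers are illustrative only:

    from memory.api.search.rerank import rerank_chunks

    async def search_with_rerank(query: str, retrieve) -> list:
        # Over-retrieve cheaply (embeddings/BM25), then keep the reranker's top 10.
        candidates = await retrieve(query, limit=50)
        return await rerank_chunks(query, candidates, top_k=10)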
313 tests/memory/api/search/test_rerank.py Normal file
@@ -0,0 +1,313 @@
"""
Tests for the rerank module (VoyageAI cross-encoder reranking).
"""

import pytest
from unittest.mock import MagicMock, patch

from memory.api.search.rerank import rerank_chunks, DEFAULT_RERANK_MODEL


class MockRerankResult:
    """Mock for VoyageAI rerank result item."""

    def __init__(self, index: int, relevance_score: float):
        self.index = index
        self.relevance_score = relevance_score


class MockRerankResponse:
    """Mock for VoyageAI rerank response."""

    def __init__(self, results: list[MockRerankResult]):
        self.results = results


def _make_chunk(content: str = "test content", score: float = 0.5):
    """Create a mock chunk."""
    chunk = MagicMock()
    chunk.content = content
    chunk.relevance_score = score
    chunk.data = None
    return chunk


# ============================================================================
# Basic reranking tests
# ============================================================================


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_basic(mock_voyageai):
    """Should rerank chunks using VoyageAI."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=1, relevance_score=0.9),
        MockRerankResult(index=0, relevance_score=0.7),
    ])

    chunks = [_make_chunk("first", 0.5), _make_chunk("second", 0.6)]
    result = await rerank_chunks("test query", chunks)

    assert len(result) == 2
    assert result[0].relevance_score == 0.9
    assert result[1].relevance_score == 0.7


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_reorders_correctly(mock_voyageai):
    """Should correctly reorder multiple chunks."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=2, relevance_score=0.95),
        MockRerankResult(index=0, relevance_score=0.85),
        MockRerankResult(index=1, relevance_score=0.75),
    ])

    chunk_a = _make_chunk("chunk a", 0.5)
    chunk_b = _make_chunk("chunk b", 0.6)
    chunk_c = _make_chunk("chunk c", 0.4)

    result = await rerank_chunks("query", [chunk_a, chunk_b, chunk_c])

    assert result[0] is chunk_c
    assert result[1] is chunk_a
    assert result[2] is chunk_b


# ============================================================================
# Empty/edge case tests
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("query", ["", " ", "\t\n"])
async def test_rerank_chunks_empty_query(query):
    """Should return original chunks for empty/whitespace query."""
    chunks = [_make_chunk("test", 0.5)]
    result = await rerank_chunks(query, chunks)
    assert result == chunks


@pytest.mark.asyncio
async def test_rerank_chunks_empty_chunks():
    """Should return empty list for empty input."""
    result = await rerank_chunks("test query", [])
    assert result == []


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_all_empty_content(mock_voyageai):
    """Should return original chunks if all have empty content."""
    chunk1 = MagicMock()
    chunk1.content = ""
    chunk1.data = None
    chunk1.relevance_score = 0.5

    chunk2 = MagicMock()
    chunk2.content = None
    chunk2.data = []
    chunk2.relevance_score = 0.6

    chunks = [chunk1, chunk2]
    result = await rerank_chunks("query", chunks)

    assert result == chunks
    mock_voyageai.Client.assert_not_called()


# ============================================================================
# Model and parameter tests
# ============================================================================


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["rerank-2", "rerank-2-lite", "custom-model"])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_specified_model(mock_voyageai, model):
    """Should use specified model."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    await rerank_chunks("query", [_make_chunk()], model=model)

    call_kwargs = mock_client.rerank.call_args[1]
    assert call_kwargs["model"] == model


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_default_model(mock_voyageai):
    """Should use default model when not specified."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    await rerank_chunks("query", [_make_chunk()])

    call_kwargs = mock_client.rerank.call_args[1]
    assert call_kwargs["model"] == DEFAULT_RERANK_MODEL


@pytest.mark.asyncio
@pytest.mark.parametrize("top_k", [1, 5, 10])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_respects_top_k(mock_voyageai, top_k):
    """Should pass top_k to VoyageAI."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    chunks = [_make_chunk() for _ in range(10)]
    await rerank_chunks("query", chunks, top_k=top_k)

    call_kwargs = mock_client.rerank.call_args[1]
    assert call_kwargs["top_k"] == top_k


# ============================================================================
# Content handling tests
# ============================================================================


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_skips_none_content(mock_voyageai):
    """Should skip chunks with None content."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    chunk_with = _make_chunk("has content", 0.5)
    chunk_without = _make_chunk(None, 0.5)
    chunk_without.content = None

    await rerank_chunks("query", [chunk_with, chunk_without])

    call_kwargs = mock_client.rerank.call_args[1]
    assert len(call_kwargs["documents"]) == 1
    assert call_kwargs["documents"][0] == "has content"


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_truncates_long_content(mock_voyageai):
    """Should truncate content to 8000 characters."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    long_content = "x" * 10000
    await rerank_chunks("query", [_make_chunk(long_content)])

    call_kwargs = mock_client.rerank.call_args[1]
    assert len(call_kwargs["documents"][0]) == 8000


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_data_fallback(mock_voyageai):
    """Should fall back to data attribute if content is empty."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    chunk = MagicMock()
    chunk.content = ""
    chunk.data = ["text from data", 123, "more text"]
    chunk.relevance_score = 0.5

    await rerank_chunks("query", [chunk])

    call_kwargs = mock_client.rerank.call_args[1]
    assert "text from data" in call_kwargs["documents"][0]
    assert "more text" in call_kwargs["documents"][0]


# ============================================================================
# Error handling tests
# ============================================================================


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_api_error(mock_voyageai):
    """Should return original chunks on API error."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.side_effect = Exception("API error")

    chunks = [_make_chunk("test", 0.5)]
    result = await rerank_chunks("query", chunks)

    assert result == chunks


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_missing_index(mock_voyageai):
    """Should handle missing indices gracefully."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
        MockRerankResult(index=99, relevance_score=0.7),  # Invalid index
    ])

    result = await rerank_chunks("query", [_make_chunk()])
    assert len(result) == 1


# ============================================================================
# Object preservation tests
# ============================================================================


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_preserves_objects(mock_voyageai):
    """Should return the same chunk objects, not copies."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.8),
    ])

    chunk = _make_chunk("test", 0.5)
    result = await rerank_chunks("query", [chunk])

    assert result[0] is chunk


@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_updates_scores(mock_voyageai):
    """Should update chunk relevance_score from reranker."""
    mock_client = MagicMock()
    mock_voyageai.Client.return_value = mock_client
    mock_client.rerank.return_value = MockRerankResponse([
        MockRerankResult(index=0, relevance_score=0.95),
    ])

    chunk = _make_chunk("test", 0.5)
    result = await rerank_chunks("query", [chunk])

    assert result[0].relevance_score == 0.95
353 tests/memory/api/search/test_search.py Normal file
@@ -0,0 +1,353 @@
"""
Tests for search module functions including RRF fusion, query term boosting,
title boosting, and source deduplication.
"""

import pytest
from unittest.mock import MagicMock, patch

from memory.api.search.search import (
    extract_query_terms,
    apply_query_term_boost,
    deduplicate_by_source,
    apply_title_boost,
    fuse_scores_rrf,
    STOPWORDS,
    QUERY_TERM_BOOST,
    TITLE_MATCH_BOOST,
    RRF_K,
)


# ============================================================================
# extract_query_terms tests
# ============================================================================


@pytest.mark.parametrize(
    "query,expected",
    [
        ("machine learning algorithms", {"machine", "learning", "algorithms"}),
        ("MACHINE Learning ALGORITHMS", {"machine", "learning", "algorithms"}),
        ("", set()),
        ("the is a an of to", set()),  # Only stopwords
    ],
)
def test_extract_query_terms_basic(query, expected):
    """Should extract meaningful terms, lowercase them, and filter stopwords."""
    assert extract_query_terms(query) == expected


@pytest.mark.parametrize(
    "query,must_include,must_exclude",
    [
        (
            "the quick brown fox jumps with the lazy dog",
            {"quick", "brown", "jumps", "lazy", "fox", "dog"},
            {"the", "with"},
        ),
        (
            "what is the best approach for neural networks",
            {"best", "approach", "neural", "networks"},
            {"what", "the", "for"},
        ),
    ],
)
def test_extract_query_terms_filtering(query, must_include, must_exclude):
    """Should filter stopwords while keeping meaningful terms."""
    terms = extract_query_terms(query)
    for term in must_include:
        assert term in terms, f"'{term}' should be in terms"
    for term in must_exclude:
        assert term not in terms, f"'{term}' should not be in terms"


@pytest.mark.parametrize(
    "query,excluded",
    [
        ("AI is a new ML model", {"ai", "is", "a", "ml"}),  # Short words filtered
    ],
)
def test_extract_query_terms_short_words(query, excluded):
    """Should filter words with 2 or fewer characters."""
    terms = extract_query_terms(query)
    for term in excluded:
        assert term not in terms


@pytest.mark.parametrize(
    "word",
    ["the", "is", "are", "was", "were", "be", "been", "have", "has", "had",
     "do", "does", "did", "to", "of", "in", "for", "on", "with", "at", "by"],
)
def test_common_stopwords_in_set(word):
    """Verify common stopwords are in the STOPWORDS set."""
    assert word in STOPWORDS

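# ----------------------------------------------------------------------------
# Editor's sketch (assumption, not part of this commit): extract_query_terms
# lives in memory/api/search/search.py and is not shown in this diff. A
# minimal implementation consistent with the tests above could look like
# this; the name _sketch_extract_query_terms and the tokenizing regex are
# hypothetical.
# ----------------------------------------------------------------------------
import re


def _sketch_extract_query_terms(query: str) -> set[str]:
    """Lowercase, tokenize, then drop stopwords and words of <= 2 characters."""
    words = re.findall(r"[a-z0-9]+", query.lower())
    return {w for w in words if len(w) > 2 and w not in STOPWORDS}
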
# ============================================================================
# apply_query_term_boost tests
# ============================================================================


def _make_chunk(content: str, source_id: int = 1, score: float = 0.5):
    """Create a mock chunk with given content and score."""
    chunk = MagicMock()
    chunk.content = content
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "content,query_terms,initial_score,expected_boost_fraction",
    [
        ("machine learning is powerful", {"machine", "learning"}, 0.5, 1.0),  # Both match
        ("machine vision systems", {"machine", "learning"}, 0.5, 0.5),  # One of two
        ("deep neural networks", {"machine", "learning"}, 0.5, 0.0),  # No match
        ("MACHINE Learning AlGoRiThMs", {"machine", "learning", "algorithms"}, 0.5, 1.0),  # Case insensitive
    ],
)
def test_apply_query_term_boost(content, query_terms, initial_score, expected_boost_fraction):
    """Should boost chunks based on query term matches."""
    chunks = [_make_chunk(content, score=initial_score)]
    apply_query_term_boost(chunks, query_terms)
    expected = initial_score + QUERY_TERM_BOOST * expected_boost_fraction
    assert chunks[0].relevance_score == pytest.approx(expected)


def test_apply_query_term_boost_empty_inputs():
    """Should handle empty query_terms or chunks."""
    chunks = [_make_chunk("machine learning", score=0.5)]
    apply_query_term_boost(chunks, set())
    assert chunks[0].relevance_score == 0.5

    apply_query_term_boost([], {"machine"})  # Should not raise


def test_apply_query_term_boost_none_values():
    """Should handle None content and relevance_score."""
    chunk_none_content = MagicMock()
    chunk_none_content.content = None
    chunk_none_content.relevance_score = 0.5
    apply_query_term_boost([chunk_none_content], {"machine"})
    assert chunk_none_content.relevance_score == 0.5

    chunk_none_score = MagicMock()
    chunk_none_score.content = "machine learning"
    chunk_none_score.relevance_score = None
    apply_query_term_boost([chunk_none_score], {"machine", "learning"})
    assert chunk_none_score.relevance_score == pytest.approx(QUERY_TERM_BOOST)


def test_apply_query_term_boost_multiple_chunks():
    """Should boost each chunk independently."""
    chunks = [
        _make_chunk("machine learning", score=0.5),
        _make_chunk("deep networks", score=0.6),
        _make_chunk("machine vision", score=0.4),
    ]
    query_terms = {"machine", "learning"}
    apply_query_term_boost(chunks, query_terms)

    assert chunks[0].relevance_score == pytest.approx(0.5 + QUERY_TERM_BOOST)
    assert chunks[1].relevance_score == 0.6  # No match
    assert chunks[2].relevance_score == pytest.approx(0.4 + QUERY_TERM_BOOST * 0.5)

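# ----------------------------------------------------------------------------
# Editor's sketch (assumption, not part of this commit): a minimal
# apply_query_term_boost consistent with the tests above -- boost each chunk
# by QUERY_TERM_BOOST scaled by the fraction of query terms found in its
# content, skipping empty/None content and treating None scores as 0.
# ----------------------------------------------------------------------------
def _sketch_apply_query_term_boost(chunks, query_terms: set[str]) -> None:
    if not query_terms:
        return
    for chunk in chunks:
        content = (chunk.content or "").lower()
        matched = sum(1 for term in query_terms if term in content)
        if matched:
            boost = QUERY_TERM_BOOST * (matched / len(query_terms))
            chunk.relevance_score = (chunk.relevance_score or 0) + boost
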
# ============================================================================
# deduplicate_by_source tests
# ============================================================================


def _make_source_chunk(source_id: int, score: float):
    """Create a mock chunk with given source_id and score."""
    chunk = MagicMock()
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "chunks_data,expected_count,expected_scores",
    [
        # Multiple chunks per source - keep highest
        ([(1, 0.5), (1, 0.8), (1, 0.3), (2, 0.6)], 2, {1: 0.8, 2: 0.6}),
        # Single chunk per source - keep all
        ([(1, 0.5), (2, 0.6), (3, 0.7)], 3, {1: 0.5, 2: 0.6, 3: 0.7}),
        # Empty list
        ([], 0, {}),
    ],
)
def test_deduplicate_by_source(chunks_data, expected_count, expected_scores):
    """Should keep only highest scoring chunk per source."""
    chunks = [_make_source_chunk(sid, score) for sid, score in chunks_data]
    result = deduplicate_by_source(chunks)

    assert len(result) == expected_count
    for chunk in result:
        assert chunk.relevance_score == expected_scores[chunk.source_id]


def test_deduplicate_by_source_preserves_objects():
    """Should return the actual chunk objects, not copies."""
    chunk1 = _make_source_chunk(1, 0.5)
    chunk2 = _make_source_chunk(1, 0.8)
    result = deduplicate_by_source([chunk1, chunk2])
    assert result[0] is chunk2


def test_deduplicate_by_source_none_scores():
    """Should handle None relevance_score as 0."""
    chunk1 = _make_source_chunk(1, None)
    chunk2 = _make_source_chunk(1, 0.5)
    result = deduplicate_by_source([chunk1, chunk2])
    assert result[0].relevance_score == 0.5

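# ----------------------------------------------------------------------------
# Editor's sketch (assumption, not part of this commit): a minimal
# deduplicate_by_source consistent with the tests above -- keep the
# highest-scoring chunk per source_id, treating None scores as 0, and return
# the original objects rather than copies.
# ----------------------------------------------------------------------------
def _sketch_deduplicate_by_source(chunks) -> list:
    best = {}
    for chunk in chunks:
        current = best.get(chunk.source_id)
        if current is None or (chunk.relevance_score or 0) > (current.relevance_score or 0):
            best[chunk.source_id] = chunk
    return list(best.values())
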
# ============================================================================
# apply_title_boost tests
# ============================================================================


def _make_title_chunk(source_id: int, score: float = 0.5):
    """Create a mock chunk for title boost tests."""
    chunk = MagicMock()
    chunk.source_id = source_id
    chunk.relevance_score = score
    return chunk


@pytest.mark.parametrize(
    "title,query_terms,initial_score,expected_boost_fraction",
    [
        ("Machine Learning Tutorial", {"machine", "learning"}, 0.5, 1.0),
        ("Machine Vision Systems", {"machine", "learning"}, 0.5, 0.5),
        ("Deep Neural Networks", {"machine", "learning"}, 0.5, 0.0),
        ("MACHINE LEARNING Tutorial", {"machine", "learning"}, 0.5, 1.0),  # Case insensitive
    ],
)
@patch("memory.api.search.search.make_session")
def test_apply_title_boost(mock_make_session, title, query_terms, initial_score, expected_boost_fraction):
    """Should boost chunks when title matches query terms."""
    mock_session = MagicMock()
    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)

    mock_source = MagicMock()
    mock_source.id = 1
    mock_source.title = title
    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

    chunks = [_make_title_chunk(1, initial_score)]
    apply_title_boost(chunks, query_terms)

    expected = initial_score + TITLE_MATCH_BOOST * expected_boost_fraction
    assert chunks[0].relevance_score == pytest.approx(expected)


def test_apply_title_boost_empty_inputs():
    """Should not modify chunks if query_terms or chunks is empty."""
    chunks = [_make_title_chunk(1, 0.5)]
    apply_title_boost(chunks, set())
    assert chunks[0].relevance_score == 0.5

    apply_title_boost([], {"machine"})  # Should not raise


@patch("memory.api.search.search.make_session")
def test_apply_title_boost_none_title(mock_make_session):
    """Should handle sources with None or missing title."""
    mock_session = MagicMock()
    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)

    # Source with None title
    mock_source = MagicMock()
    mock_source.id = 1
    mock_source.title = None
    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

    chunks = [_make_title_chunk(1, 0.5)]
    apply_title_boost(chunks, {"machine"})
    assert chunks[0].relevance_score == 0.5

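# ----------------------------------------------------------------------------
# Editor's sketch (assumption, not part of this commit): an apply_title_boost
# consistent with the tests above. The tests only constrain the
# session.query(...).filter(...).all() call chain, so the SourceItem model
# name is hypothetical; make_session is the context manager patched at
# memory.api.search.search.make_session. Names below are undefined in this
# test module -- sketch only, never called.
# ----------------------------------------------------------------------------
def _sketch_apply_title_boost(chunks, query_terms: set[str]) -> None:
    if not chunks or not query_terms:
        return
    with make_session() as session:  # noqa: F821 - sketch only
        sources = session.query(SourceItem).filter(  # noqa: F821 - sketch only
            SourceItem.id.in_({c.source_id for c in chunks})  # noqa: F821
        ).all()
    titles = {source.id: (source.title or "").lower() for source in sources}
    for chunk in chunks:
        title = titles.get(chunk.source_id, "")
        matched = sum(1 for term in query_terms if term in title)
        if matched:
            boost = TITLE_MATCH_BOOST * (matched / len(query_terms))
            chunk.relevance_score = (chunk.relevance_score or 0) + boost
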
# ============================================================================
# fuse_scores_rrf tests
# ============================================================================


@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_key,expected_score",
    [
        # Both sources have same ranking
        ({"a": 0.9, "b": 0.7}, {"a": 0.8, "b": 0.6}, "a", 2 / (RRF_K + 1)),
        # Item only in embeddings
        ({"a": 0.9, "b": 0.7}, {"a": 0.8}, "b", 1 / (RRF_K + 2)),
        # Item only in BM25
        ({"a": 0.9}, {"a": 0.8, "b": 0.7}, "b", 1 / (RRF_K + 2)),
        # Single item in both
        ({"a": 0.9}, {"a": 0.8}, "a", 2 / (RRF_K + 1)),
    ],
)
def test_fuse_scores_rrf_basic(embedding_scores, bm25_scores, expected_key, expected_score):
    """Should compute RRF scores correctly."""
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert result[expected_key] == pytest.approx(expected_score)


def test_fuse_scores_rrf_different_rankings():
    """Should handle items ranked differently in each source."""
    embedding_scores = {"a": 0.9, "b": 0.5}  # a=1, b=2
    bm25_scores = {"a": 0.3, "b": 0.8}  # b=1, a=2

    result = fuse_scores_rrf(embedding_scores, bm25_scores)

    # Both should have same RRF score (1/61 + 1/62)
    expected = 1 / (RRF_K + 1) + 1 / (RRF_K + 2)
    assert result["a"] == pytest.approx(expected)
    assert result["b"] == pytest.approx(expected)


@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_len",
    [
        ({}, {}, 0),
        ({}, {"a": 0.8, "b": 0.6}, 2),
        ({"a": 0.9, "b": 0.7}, {}, 2),
    ],
)
def test_fuse_scores_rrf_empty_inputs(embedding_scores, bm25_scores, expected_len):
    """Should handle empty inputs gracefully."""
    result = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert len(result) == expected_len


def test_fuse_scores_rrf_many_items():
    """Should handle many items correctly."""
    embedding_scores = {str(i): 1.0 - i * 0.01 for i in range(100)}
    bm25_scores = {str(i): 1.0 - i * 0.01 for i in range(100)}

    result = fuse_scores_rrf(embedding_scores, bm25_scores)

    assert len(result) == 100
    assert result["0"] > result["99"]  # First should have highest score


def test_fuse_scores_rrf_only_ranks_matter():
    """RRF should only care about ranks, not score magnitudes."""
    # Same ranking, different score scales
    result1 = fuse_scores_rrf(
        {"a": 0.99, "b": 0.98, "c": 0.97},
        {"a": 100, "b": 50, "c": 1},
    )
    result2 = fuse_scores_rrf(
        {"a": 0.5, "b": 0.4, "c": 0.3},
        {"a": 0.9, "b": 0.8, "c": 0.7},
    )

    # RRF scores should be identical since rankings are the same
    assert result1["a"] == pytest.approx(result2["a"])
    assert result1["b"] == pytest.approx(result2["b"])
    assert result1["c"] == pytest.approx(result2["c"])

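# ----------------------------------------------------------------------------
# Editor's sketch (assumption, not part of this commit): Reciprocal Rank
# Fusion scores each item by sum(1 / (RRF_K + rank)) over the rankings it
# appears in, with ranks starting at 1 -- only ranks matter, never raw score
# magnitudes. The "1/61 + 1/62" comment above implies RRF_K == 60. A minimal
# implementation consistent with these tests:
# ----------------------------------------------------------------------------
def _sketch_fuse_scores_rrf(embedding_scores: dict, bm25_scores: dict) -> dict:
    fused: dict = {}
    for scores in (embedding_scores, bm25_scores):
        ranked = sorted(scores, key=scores.get, reverse=True)
        for rank, key in enumerate(ranked, start=1):
            fused[key] = fused.get(key, 0.0) + 1.0 / (RRF_K + rank)
    return fused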