mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 17:22:58 +01:00
Add comprehensive tests for search improvements
- Add tests for extract_query_terms (stopword filtering, short words)
- Add tests for apply_query_term_boost (boost calculations, edge cases)
- Add tests for deduplicate_by_source (keeps highest per source)
- Add tests for apply_title_boost (title matching with mocked DB)
- Add tests for fuse_scores_rrf (RRF score fusion, ranking behavior)
- Add tests for rerank module (VoyageAI reranker mocking)

Uses pytest.mark.parametrize for concise, data-driven tests.
77 tests total covering all new search functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
c6cb793cdf
commit
09215adf9a
94
src/memory/api/search/rerank.py
Normal file
94
src/memory/api/search/rerank.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
Cross-encoder reranking using VoyageAI's reranker.
|
||||||
|
|
||||||
|
Reranking improves search precision by using a cross-encoder model that
|
||||||
|
sees query and document together, rather than comparing embeddings separately.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import voyageai
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.models import Chunk
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# VoyageAI reranker models
|
||||||
|
# rerank-2: More accurate, slower
|
||||||
|
# rerank-2-lite: Faster, slightly less accurate
|
||||||
|
DEFAULT_RERANK_MODEL = "rerank-2-lite"
|
||||||
|
|
||||||
|
|
||||||
|
async def rerank_chunks(
    query: str,
    chunks: list[Chunk],
    model: str = DEFAULT_RERANK_MODEL,
    top_k: Optional[int] = None,
) -> list[Chunk]:
    """
    Rerank chunks using VoyageAI's cross-encoder reranker.

    Cross-encoders are more accurate than bi-encoders (embeddings) because
    they see query and document together, allowing for deeper semantic matching.

    The text sent for each chunk is ``chunk.content``; when that is empty the
    string items of ``chunk.data`` (if any) are joined with newlines instead.
    Chunks that yield no text at all are not sent to the reranker.

    Args:
        query: The search query
        chunks: List of candidate chunks to rerank
        model: VoyageAI reranker model to use
        top_k: If set, only return top k results. A falsy value (None or 0)
            requests a score for every document.

    Returns:
        On success, the chunks in the order returned by the reranker, each
        with ``relevance_score`` overwritten in place by the reranker score.
        Chunks without text are omitted from this list.

        The input list is returned unchanged (``top_k`` is not applied) when
        the query is blank, when no chunk has any text, or when the reranker
        call or result mapping raises.
    """
    if not chunks:
        return []

    if not query.strip():
        return chunks

    # Per-document character cap applied before sending text to the reranker.
    max_document_chars = 8000

    # documents[i] is the text sent for candidates[i]; the reranker reports
    # positions into `documents`, so the two lists must stay in lockstep.
    documents: list[str] = []
    candidates: list[Chunk] = []

    for chunk in chunks:
        content = chunk.content or ""
        if not content:
            # Best-effort fallback: reading `data` happens inside the try so an
            # accessor that raises cannot abort the whole rerank.
            try:
                data = getattr(chunk, "data", None)
                if data is not None:
                    content = "\n".join(d for d in data if isinstance(d, str))
            except Exception:
                logger.debug(
                    "Could not read fallback text from chunk data", exc_info=True
                )

        if content:
            documents.append(content[:max_document_chars])
            candidates.append(chunk)

    if not documents:
        return chunks

    try:
        vo = voyageai.Client()
        # The client call is synchronous; run it in a worker thread so the
        # event loop is not blocked.
        result = await asyncio.to_thread(
            vo.rerank,
            query=query,
            documents=documents,
            model=model,
            top_k=top_k or len(documents),
        )

        # Map results back to chunks with updated scores, ignoring any index
        # that does not refer to a document we sent.
        reranked = []
        for item in result.results:
            if 0 <= item.index < len(candidates):
                chunk = candidates[item.index]
                chunk.relevance_score = item.relevance_score
                reranked.append(chunk)

        return reranked

    except Exception as e:
        logger.warning("Reranking failed, returning original order: %s", e)
        return chunks
|
||||||
313
tests/memory/api/search/test_rerank.py
Normal file
313
tests/memory/api/search/test_rerank.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
"""
|
||||||
|
Tests for the rerank module (VoyageAI cross-encoder reranking).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from memory.api.search.rerank import rerank_chunks, DEFAULT_RERANK_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
class MockRerankResult:
    """Stand-in for one scored item of a VoyageAI rerank response."""

    def __init__(self, index: int, relevance_score: float):
        # Only the two attributes that rerank_chunks reads are modelled.
        self.index, self.relevance_score = index, relevance_score
|
||||||
|
|
||||||
|
|
||||||
|
class MockRerankResponse:
    """Stand-in for the object returned by the VoyageAI client's ``rerank``."""

    def __init__(self, results: list[MockRerankResult]):
        # Exposed as `.results`, which is the attribute rerank_chunks iterates.
        self.results = results
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(content: str = "test content", score: float = 0.5):
|
||||||
|
"""Create a mock chunk."""
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.content = content
|
||||||
|
chunk.relevance_score = score
|
||||||
|
chunk.data = None
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Basic reranking tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_basic(mock_voyageai):
    """Reranked chunks carry the scores reported by the mocked VoyageAI client."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [
            MockRerankResult(index=1, relevance_score=0.9),
            MockRerankResult(index=0, relevance_score=0.7),
        ]
    )

    candidates = [_make_chunk("first", 0.5), _make_chunk("second", 0.6)]
    result = await rerank_chunks("test query", candidates)

    assert len(result) == 2
    best, runner_up = result
    assert best.relevance_score == 0.9
    assert runner_up.relevance_score == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_reorders_correctly(mock_voyageai):
    """Output order follows the reranker's result order, not the input order."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [
            MockRerankResult(index=2, relevance_score=0.95),
            MockRerankResult(index=0, relevance_score=0.85),
            MockRerankResult(index=1, relevance_score=0.75),
        ]
    )

    chunk_a = _make_chunk("chunk a", 0.5)
    chunk_b = _make_chunk("chunk b", 0.6)
    chunk_c = _make_chunk("chunk c", 0.4)

    result = await rerank_chunks("query", [chunk_a, chunk_b, chunk_c])

    # Unpacking fails loudly if fewer than three chunks came back.
    top, middle, bottom = result[:3]
    assert top is chunk_c
    assert middle is chunk_a
    assert bottom is chunk_b
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Empty/edge case tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("query", ["", " ", "\t\n"])
async def test_rerank_chunks_empty_query(query):
    """A blank or whitespace-only query hands the input chunks straight back."""
    original = [_make_chunk("test", 0.5)]
    assert await rerank_chunks(query, original) == original
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rerank_chunks_empty_chunks():
    """No candidate chunks in means an empty list out."""
    assert await rerank_chunks("test query", []) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_all_empty_content(mock_voyageai):
    """With no text on any chunk, the input is returned and no client is built."""
    blank_string = MagicMock(content="", data=None, relevance_score=0.5)
    blank_none = MagicMock(content=None, data=[], relevance_score=0.6)

    chunks = [blank_string, blank_none]
    result = await rerank_chunks("query", chunks)

    assert result == chunks
    mock_voyageai.Client.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Model and parameter tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["rerank-2", "rerank-2-lite", "custom-model"])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_specified_model(mock_voyageai, model):
    """The ``model`` argument is forwarded verbatim to the client's rerank call."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    await rerank_chunks("query", [_make_chunk()], model=model)

    assert client.rerank.call_args.kwargs["model"] == model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_default_model(mock_voyageai):
    """Omitting ``model`` falls back to DEFAULT_RERANK_MODEL."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    await rerank_chunks("query", [_make_chunk()])

    assert client.rerank.call_args.kwargs["model"] == DEFAULT_RERANK_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("top_k", [1, 5, 10])
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_respects_top_k(mock_voyageai, top_k):
    """An explicit ``top_k`` is forwarded to the client's rerank call."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    ten_chunks = [_make_chunk() for _ in range(10)]
    await rerank_chunks("query", ten_chunks, top_k=top_k)

    assert client.rerank.call_args.kwargs["top_k"] == top_k
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Content handling tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_skips_none_content(mock_voyageai):
    """A chunk whose content is None is not sent to the reranker."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    with_text = _make_chunk("has content", 0.5)
    without_text = _make_chunk(None, 0.5)

    await rerank_chunks("query", [with_text, without_text])

    sent_documents = client.rerank.call_args.kwargs["documents"]
    assert len(sent_documents) == 1
    assert sent_documents[0] == "has content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_truncates_long_content(mock_voyageai):
    """Text longer than 8000 characters is cut to 8000 before being sent."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    oversized = _make_chunk("x" * 10000)
    await rerank_chunks("query", [oversized])

    sent_documents = client.rerank.call_args.kwargs["documents"]
    assert len(sent_documents[0]) == 8000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_uses_data_fallback(mock_voyageai):
    """Empty content makes the string items of ``data`` the document text."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    # The integer item is there to show non-string entries are tolerated.
    chunk = MagicMock(
        content="",
        data=["text from data", 123, "more text"],
        relevance_score=0.5,
    )

    await rerank_chunks("query", [chunk])

    document = client.rerank.call_args.kwargs["documents"][0]
    assert "text from data" in document
    assert "more text" in document
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Error handling tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_api_error(mock_voyageai):
    """An exception from the client results in the input chunks being returned."""
    mock_voyageai.Client.return_value.rerank.side_effect = Exception("API error")

    original = [_make_chunk("test", 0.5)]

    assert await rerank_chunks("query", original) == original
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_handles_missing_index(mock_voyageai):
    """A result index with no matching document is skipped rather than raising."""
    valid = MockRerankResult(index=0, relevance_score=0.8)
    out_of_range = MockRerankResult(index=99, relevance_score=0.7)  # Invalid index
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse([valid, out_of_range])

    result = await rerank_chunks("query", [_make_chunk()])

    assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Object preservation tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_preserves_objects(mock_voyageai):
    """The returned list holds the very chunk objects passed in, not copies."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.8)]
    )

    only_chunk = _make_chunk("test", 0.5)
    result = await rerank_chunks("query", [only_chunk])

    assert result[0] is only_chunk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch("memory.api.search.rerank.voyageai")
async def test_rerank_chunks_updates_scores(mock_voyageai):
    """The reranker's score replaces the chunk's previous relevance_score."""
    client = mock_voyageai.Client.return_value
    client.rerank.return_value = MockRerankResponse(
        [MockRerankResult(index=0, relevance_score=0.95)]
    )

    result = await rerank_chunks("query", [_make_chunk("test", 0.5)])

    assert result[0].relevance_score == 0.95
|
||||||
353
tests/memory/api/search/test_search.py
Normal file
353
tests/memory/api/search/test_search.py
Normal file
@ -0,0 +1,353 @@
|
|||||||
|
"""
|
||||||
|
Tests for search module functions including RRF fusion, query term boosting,
|
||||||
|
title boosting, and source deduplication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from memory.api.search.search import (
|
||||||
|
extract_query_terms,
|
||||||
|
apply_query_term_boost,
|
||||||
|
deduplicate_by_source,
|
||||||
|
apply_title_boost,
|
||||||
|
fuse_scores_rrf,
|
||||||
|
STOPWORDS,
|
||||||
|
QUERY_TERM_BOOST,
|
||||||
|
TITLE_MATCH_BOOST,
|
||||||
|
RRF_K,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# extract_query_terms tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "query,expected",
    [
        ("machine learning algorithms", {"machine", "learning", "algorithms"}),
        ("MACHINE Learning ALGORITHMS", {"machine", "learning", "algorithms"}),
        ("", set()),
        ("the is a an of to", set()),  # Only stopwords
    ],
)
def test_extract_query_terms_basic(query, expected):
    """Terms come back lowercased, with stopwords removed."""
    terms = extract_query_terms(query)
    assert terms == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "query,must_include,must_exclude",
    [
        (
            "the quick brown fox jumps with the lazy dog",
            {"quick", "brown", "jumps", "lazy", "fox", "dog"},
            {"the", "with"},
        ),
        (
            "what is the best approach for neural networks",
            {"best", "approach", "neural", "networks"},
            {"what", "the", "for"},
        ),
    ],
)
def test_extract_query_terms_filtering(query, must_include, must_exclude):
    """Stopwords are dropped from a sentence while content words survive."""
    terms = extract_query_terms(query)

    # Checked one term at a time so a failure names the offending word.
    for term in must_include:
        assert term in terms, f"'{term}' should be in terms"

    for term in must_exclude:
        assert term not in terms, f"'{term}' should not be in terms"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "query,excluded",
    [
        ("AI is a new ML model", {"ai", "is", "a", "ml"}),  # Short words filtered
    ],
)
def test_extract_query_terms_short_words(query, excluded):
    """Words of one or two characters never appear among the extracted terms."""
    terms = extract_query_terms(query)
    assert all(term not in terms for term in excluded)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "word",
    [
        "the", "is", "are", "was", "were", "be", "been",
        "have", "has", "had", "do", "does", "did",
        "to", "of", "in", "for", "on", "with", "at", "by",
    ],
)
def test_common_stopwords_in_set(word):
    """Each of these everyday function words is a member of STOPWORDS."""
    assert word in STOPWORDS
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# apply_query_term_boost tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(content: str, source_id: int = 1, score: float = 0.5):
|
||||||
|
"""Create a mock chunk with given content and score."""
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.content = content
|
||||||
|
chunk.source_id = source_id
|
||||||
|
chunk.relevance_score = score
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "content,query_terms,initial_score,expected_boost_fraction",
    [
        # Both match
        ("machine learning is powerful", {"machine", "learning"}, 0.5, 1.0),
        # One of two
        ("machine vision systems", {"machine", "learning"}, 0.5, 0.5),
        # No match
        ("deep neural networks", {"machine", "learning"}, 0.5, 0.0),
        # Case insensitive
        ("MACHINE Learning AlGoRiThMs", {"machine", "learning", "algorithms"}, 0.5, 1.0),
    ],
)
def test_apply_query_term_boost(content, query_terms, initial_score, expected_boost_fraction):
    """The boost scales with the fraction of query terms found in the content."""
    chunk = _make_chunk(content, score=initial_score)

    apply_query_term_boost([chunk], query_terms)

    boosted = initial_score + QUERY_TERM_BOOST * expected_boost_fraction
    assert chunk.relevance_score == pytest.approx(boosted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_query_term_boost_empty_inputs():
    """An empty term set leaves scores alone; an empty chunk list is accepted."""
    chunk = _make_chunk("machine learning", score=0.5)
    apply_query_term_boost([chunk], set())
    assert chunk.relevance_score == 0.5

    # Should not raise
    apply_query_term_boost([], {"machine"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_query_term_boost_none_values():
    """None content gets no boost; a None score is boosted as if it were zero."""
    no_content = MagicMock(content=None, relevance_score=0.5)
    apply_query_term_boost([no_content], {"machine"})
    assert no_content.relevance_score == 0.5

    no_score = MagicMock(content="machine learning", relevance_score=None)
    apply_query_term_boost([no_score], {"machine", "learning"})
    assert no_score.relevance_score == pytest.approx(QUERY_TERM_BOOST)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_query_term_boost_multiple_chunks():
    """Each chunk's boost depends only on its own content."""
    full_match = _make_chunk("machine learning", score=0.5)
    no_match = _make_chunk("deep networks", score=0.6)
    half_match = _make_chunk("machine vision", score=0.4)

    apply_query_term_boost([full_match, no_match, half_match], {"machine", "learning"})

    assert full_match.relevance_score == pytest.approx(0.5 + QUERY_TERM_BOOST)
    assert no_match.relevance_score == 0.6  # No match
    assert half_match.relevance_score == pytest.approx(0.4 + QUERY_TERM_BOOST * 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# deduplicate_by_source tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _make_source_chunk(source_id: int, score: float):
|
||||||
|
"""Create a mock chunk with given source_id and score."""
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.source_id = source_id
|
||||||
|
chunk.relevance_score = score
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "chunks_data,expected_count,expected_scores",
    [
        # Multiple chunks per source - keep highest
        ([(1, 0.5), (1, 0.8), (1, 0.3), (2, 0.6)], 2, {1: 0.8, 2: 0.6}),
        # Single chunk per source - keep all
        ([(1, 0.5), (2, 0.6), (3, 0.7)], 3, {1: 0.5, 2: 0.6, 3: 0.7}),
        # Empty list
        ([], 0, {}),
    ],
)
def test_deduplicate_by_source(chunks_data, expected_count, expected_scores):
    """Exactly one chunk survives per source, and it is the best-scoring one."""
    candidates = [_make_source_chunk(source_id, score) for source_id, score in chunks_data]

    survivors = deduplicate_by_source(candidates)

    assert len(survivors) == expected_count
    for survivor in survivors:
        assert survivor.relevance_score == expected_scores[survivor.source_id]
|
||||||
|
|
||||||
|
|
||||||
|
def test_deduplicate_by_source_preserves_objects():
    """The surviving entry is the original chunk object, not a copy."""
    lower = _make_source_chunk(1, 0.5)
    higher = _make_source_chunk(1, 0.8)

    survivors = deduplicate_by_source([lower, higher])

    assert survivors[0] is higher
|
||||||
|
|
||||||
|
|
||||||
|
def test_deduplicate_by_source_none_scores():
    """A chunk with a None score loses to a scored chunk from the same source."""
    unscored = _make_source_chunk(1, None)
    scored = _make_source_chunk(1, 0.5)

    survivors = deduplicate_by_source([unscored, scored])

    assert survivors[0].relevance_score == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# apply_title_boost tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _make_title_chunk(source_id: int, score: float = 0.5):
|
||||||
|
"""Create a mock chunk for title boost tests."""
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.source_id = source_id
|
||||||
|
chunk.relevance_score = score
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "title,query_terms,initial_score,expected_boost_fraction",
    [
        ("Machine Learning Tutorial", {"machine", "learning"}, 0.5, 1.0),
        ("Machine Vision Systems", {"machine", "learning"}, 0.5, 0.5),
        ("Deep Neural Networks", {"machine", "learning"}, 0.5, 0.0),
        # Case insensitive
        ("MACHINE LEARNING Tutorial", {"machine", "learning"}, 0.5, 1.0),
    ],
)
@patch("memory.api.search.search.make_session")
def test_apply_title_boost(mock_make_session, title, query_terms, initial_score, expected_boost_fraction):
    """The boost scales with the fraction of query terms found in the source title."""
    # make_session() is used as a context manager; entering it yields the session.
    session = MagicMock()
    session_cm = mock_make_session.return_value
    session_cm.__enter__ = MagicMock(return_value=session)
    session_cm.__exit__ = MagicMock(return_value=None)

    # The source lookup returns a single source with id 1 and the given title.
    source = MagicMock()
    source.id = 1
    source.title = title
    session.query.return_value.filter.return_value.all.return_value = [source]

    chunk = _make_title_chunk(1, initial_score)
    apply_title_boost([chunk], query_terms)

    boosted = initial_score + TITLE_MATCH_BOOST * expected_boost_fraction
    assert chunk.relevance_score == pytest.approx(boosted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_title_boost_empty_inputs():
    """An empty term set leaves scores alone; an empty chunk list is accepted."""
    chunk = _make_title_chunk(1, 0.5)
    apply_title_boost([chunk], set())
    assert chunk.relevance_score == 0.5

    # Should not raise
    apply_title_boost([], {"machine"})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("memory.api.search.search.make_session")
def test_apply_title_boost_none_title(mock_make_session):
    """A source whose title is None leaves the chunk's score unchanged."""
    session = MagicMock()
    session_cm = mock_make_session.return_value
    session_cm.__enter__ = MagicMock(return_value=session)
    session_cm.__exit__ = MagicMock(return_value=None)

    # Source with None title
    untitled = MagicMock()
    untitled.id = 1
    untitled.title = None
    session.query.return_value.filter.return_value.all.return_value = [untitled]

    chunk = _make_title_chunk(1, 0.5)
    apply_title_boost([chunk], {"machine"})

    assert chunk.relevance_score == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# fuse_scores_rrf tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_key,expected_score",
    [
        # Both sources have same ranking
        ({"a": 0.9, "b": 0.7}, {"a": 0.8, "b": 0.6}, "a", 2 / (RRF_K + 1)),
        # Item only in embeddings
        ({"a": 0.9, "b": 0.7}, {"a": 0.8}, "b", 1 / (RRF_K + 2)),
        # Item only in BM25
        ({"a": 0.9}, {"a": 0.8, "b": 0.7}, "b", 1 / (RRF_K + 2)),
        # Single item in both
        ({"a": 0.9}, {"a": 0.8}, "a", 2 / (RRF_K + 1)),
    ],
)
def test_fuse_scores_rrf_basic(embedding_scores, bm25_scores, expected_key, expected_score):
    """The fused score for a key is the sum of 1 / (RRF_K + rank) over the lists containing it."""
    fused = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert fused[expected_key] == pytest.approx(expected_score)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuse_scores_rrf_different_rankings():
    """Two items that swap ranks between the sources end up with equal fused scores."""
    by_embedding = {"a": 0.9, "b": 0.5}  # a is rank 1, b is rank 2
    by_bm25 = {"a": 0.3, "b": 0.8}  # b is rank 1, a is rank 2

    fused = fuse_scores_rrf(by_embedding, by_bm25)

    # Each item holds one first place and one second place.
    one_first_one_second = 1 / (RRF_K + 1) + 1 / (RRF_K + 2)
    for key in ("a", "b"):
        assert fused[key] == pytest.approx(one_first_one_second)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "embedding_scores,bm25_scores,expected_len",
    [
        ({}, {}, 0),
        ({}, {"a": 0.8, "b": 0.6}, 2),
        ({"a": 0.9, "b": 0.7}, {}, 2),
    ],
)
def test_fuse_scores_rrf_empty_inputs(embedding_scores, bm25_scores, expected_len):
    """One or both score maps may be empty; every key from either side is kept."""
    fused = fuse_scores_rrf(embedding_scores, bm25_scores)
    assert len(fused) == expected_len
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuse_scores_rrf_many_items():
    """A hundred identically ranked items all fuse, with the top-ranked scoring highest."""
    descending = {str(i): 1.0 - i * 0.01 for i in range(100)}

    # Both sources get equal (but separate) score maps.
    fused = fuse_scores_rrf(descending, dict(descending))

    assert len(fused) == 100
    assert fused["0"] > fused["99"]  # First should have highest score
|
||||||
|
|
||||||
|
|
||||||
|
def test_fuse_scores_rrf_only_ranks_matter():
    """Fused scores depend on rank order only, not on the magnitude of the inputs."""
    # Same a > b > c ordering in every map, on very different scales.
    wide_scale = fuse_scores_rrf(
        {"a": 0.99, "b": 0.98, "c": 0.97},
        {"a": 100, "b": 50, "c": 1},
    )
    narrow_scale = fuse_scores_rrf(
        {"a": 0.5, "b": 0.4, "c": 0.3},
        {"a": 0.9, "b": 0.8, "c": 0.7},
    )

    # Identical rankings must give identical fused scores.
    for key in ("a", "b", "c"):
        assert wide_scale[key] == pytest.approx(narrow_scale[key])
|
||||||
Loading…
x
Reference in New Issue
Block a user