From d9fcfe3878e4484ca34ad2d1984d640cd2f60928 Mon Sep 17 00:00:00 2001 From: mruwnik Date: Sun, 21 Dec 2025 12:29:27 +0000 Subject: [PATCH] more search improvements --- docker-compose.yaml | 2 +- requirements/requirements-api.txt | 1 - src/memory/api/search/search.py | 198 ++++++++++++++------ src/memory/common/settings.py | 2 + tests/memory/api/search/test_search.py | 243 +++++++++++++++++++++++++ 5 files changed, 392 insertions(+), 54 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index b692d1b..fa7409c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -179,7 +179,7 @@ services: VITE_SERVER_URL: "${SERVER_URL:-http://localhost:8000}" STATIC_DIR: "/app/static" VOYAGE_API_KEY: ${VOYAGE_API_KEY} - ENABLE_BM25_SEARCH: false + ENABLE_BM25_SEARCH: true OPENAI_API_KEY_FILE: /run/secrets/openai_key ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key secrets: [postgres_password, openai_key, anthropic_key] diff --git a/requirements/requirements-api.txt b/requirements/requirements-api.txt index 5fb7f1d..3e1a402 100644 --- a/requirements/requirements-api.txt +++ b/requirements/requirements-api.txt @@ -4,5 +4,4 @@ python-jose==3.3.0 python-multipart==0.0.9 sqladmin==0.20.1 mcp==1.10.0 -bm25s[full]==0.2.13 slowapi==0.1.9 \ No newline at end of file diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py index 32b63dc..5912547 100644 --- a/src/memory/api/search/search.py +++ b/src/memory/api/search/search.py @@ -4,7 +4,10 @@ Search endpoints for the knowledge base API. import asyncio import logging +import math +import re from collections import defaultdict +from datetime import datetime, timezone from typing import Optional from sqlalchemy.orm import load_only from memory.common import extract, settings @@ -51,6 +54,76 @@ TITLE_MATCH_BOOST = 0.01 # This gives a small boost to popular items without dominating relevance POPULARITY_BOOST = 0.02 +# Recency boost settings +# Maximum bonus for brand new content (additive) +RECENCY_BOOST_MAX = 0.005 +# Half-life in days: content loses half its recency boost every N days +RECENCY_HALF_LIFE_DAYS = 90 + +# Query expansion: map abbreviations/acronyms to full forms +# These help match when users search for "ML" but documents say "machine learning" +QUERY_EXPANSIONS: dict[str, list[str]] = { + # AI/ML abbreviations + "ai": ["artificial intelligence"], + "ml": ["machine learning"], + "dl": ["deep learning"], + "nlp": ["natural language processing"], + "cv": ["computer vision"], + "rl": ["reinforcement learning"], + "llm": ["large language model", "language model"], + "gpt": ["generative pretrained transformer", "language model"], + "nn": ["neural network"], + "cnn": ["convolutional neural network"], + "rnn": ["recurrent neural network"], + "lstm": ["long short term memory"], + "gan": ["generative adversarial network"], + "rag": ["retrieval augmented generation"], + # Rationality/EA terms + "ea": ["effective altruism"], + "lw": ["lesswrong", "less wrong"], + "gwwc": ["giving what we can"], + "agi": ["artificial general intelligence"], + "asi": ["artificial superintelligence"], + "fai": ["friendly ai", "ai alignment"], + "x-risk": ["existential risk"], + "xrisk": ["existential risk"], + "p(doom)": ["probability of doom", "ai risk"], + # Reverse mappings (full forms -> abbreviations) + "artificial intelligence": ["ai"], + "machine learning": ["ml"], + "deep learning": ["dl"], + "natural language processing": ["nlp"], + "computer vision": ["cv"], + "reinforcement learning": ["rl"], + "neural network": ["nn"], + "effective altruism": ["ea"], + "existential risk": ["x-risk", "xrisk"], +} + + +def expand_query(query: str) -> str: + """ + Expand query with synonyms and abbreviations. + + This helps match documents that use different terminology for the same concept. + For example, "ML algorithms" -> "ML machine learning algorithms" + """ + query_lower = query.lower() + expansions = [] + + for term, synonyms in QUERY_EXPANSIONS.items(): + # Check if term appears as a whole word in the query + # Use word boundaries to avoid matching partial words + pattern = r'\b' + re.escape(term) + r'\b' + if re.search(pattern, query_lower): + expansions.extend(synonyms) + + if expansions: + # Add expansions to the original query + return query + " " + " ".join(expansions) + return query + + # Common words to ignore when checking for query term presence STOPWORDS = { "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", @@ -116,66 +189,78 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]: return list(best_by_source.values()) -def apply_title_boost( +def apply_source_boosts( chunks: list[Chunk], query_terms: set[str], ) -> None: """ - Boost chunks when query terms match the source title. + Apply title, popularity, and recency boosts to chunks in a single DB query. - Title matches are a strong signal since titles summarize content. - """ - if not query_terms or not chunks: - return - - # Get unique source IDs - source_ids = list({chunk.source_id for chunk in chunks}) - - # Fetch full source items (polymorphic) to access title attribute - with make_session() as db: - sources = db.query(SourceItem).filter( - SourceItem.id.in_(source_ids) - ).all() - titles = {s.id: (getattr(s, 'title', None) or "").lower() for s in sources} - - # Apply boost to chunks whose source title matches query terms - for chunk in chunks: - title = titles.get(chunk.source_id, "") - if not title: - continue - - matches = sum(1 for term in query_terms if term in title) - if matches > 0: - boost = TITLE_MATCH_BOOST * (matches / len(query_terms)) - chunk.relevance_score = (chunk.relevance_score or 0) + boost - - -def apply_popularity_boost(chunks: list[Chunk]) -> None: - """ - Boost chunks based on source popularity. - - Uses the popularity property from SourceItem subclasses. - ForumPost uses karma, others default to 1.0. + - Title boost: chunks get boosted when query terms appear in source title + - Popularity boost: chunks get boosted based on source karma/popularity + - Recency boost: newer content gets a small boost that decays over time """ if not chunks: return source_ids = list({chunk.source_id for chunk in chunks}) + now = datetime.now(timezone.utc) + # Single query to fetch all source metadata with make_session() as db: sources = db.query(SourceItem).filter( SourceItem.id.in_(source_ids) ).all() - popularity_map = {s.id: s.popularity for s in sources} + source_map = { + s.id: { + "title": (getattr(s, "title", None) or "").lower(), + "popularity": s.popularity, + "inserted_at": s.inserted_at, + } + for s in sources + } for chunk in chunks: - popularity = popularity_map.get(chunk.source_id, 1.0) + source_data = source_map.get(chunk.source_id, {}) + score = chunk.relevance_score or 0 + + # Apply title boost if query terms match + if query_terms: + title = source_data.get("title", "") + if title: + matches = sum(1 for term in query_terms if term in title) + if matches > 0: + score += TITLE_MATCH_BOOST * (matches / len(query_terms)) + + # Apply popularity boost + popularity = source_data.get("popularity", 1.0) if popularity != 1.0: - # Apply boost: score * (1 + POPULARITY_BOOST * (popularity - 1)) - # For popularity=2.0: multiplier = 1.02 - # For popularity=0.5: multiplier = 0.99 multiplier = 1.0 + POPULARITY_BOOST * (popularity - 1.0) - chunk.relevance_score = (chunk.relevance_score or 0) * multiplier + score *= multiplier + + # Apply recency boost (exponential decay with half-life) + inserted_at = source_data.get("inserted_at") + if inserted_at: + # Handle timezone-naive timestamps + if inserted_at.tzinfo is None: + inserted_at = inserted_at.replace(tzinfo=timezone.utc) + age_days = (now - inserted_at).total_seconds() / 86400 + # Exponential decay: boost = max_boost * 0.5^(age/half_life) + decay = math.pow(0.5, age_days / RECENCY_HALF_LIFE_DAYS) + score += RECENCY_BOOST_MAX * decay + + chunk.relevance_score = score + + +# Keep legacy functions for backwards compatibility and testing +def apply_title_boost(chunks: list[Chunk], query_terms: set[str]) -> None: + """Legacy function - use apply_source_boosts instead.""" + apply_source_boosts(chunks, query_terms) + + +def apply_popularity_boost(chunks: list[Chunk]) -> None: + """Legacy function - use apply_source_boosts instead.""" + apply_source_boosts(chunks, set()) def fuse_scores_rrf( @@ -242,12 +327,21 @@ async def search_chunks( # This helps find results that rank well in one method but not the other internal_limit = limit * CANDIDATE_MULTIPLIER - # Extract query text for HyDE expansion - search_data = list(data) # Copy to avoid modifying original + # Extract query text and apply synonym/abbreviation expansion + query_text = " ".join( + c for chunk in data for c in chunk.data if isinstance(c, str) + ) + expanded_query = expand_query(query_text) + + # If query was expanded, use expanded version for search + if expanded_query != query_text: + logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'") + search_data = [extract.DataChunk(data=[expanded_query])] + else: + search_data = list(data) # Copy to avoid modifying original + + # Apply HyDE expansion if enabled if settings.ENABLE_HYDE_EXPANSION: - query_text = " ".join( - c for chunk in data for c in chunk.data if isinstance(c, str) - ) # Only expand queries with 4+ words (short queries are usually specific enough) if len(query_text.split()) >= 4: try: @@ -316,15 +410,15 @@ async def search_chunks( c for chunk in data for c in chunk.data if isinstance(c, str) ) - # Apply query term presence boost and title boost + # Apply query term presence boost if chunks and query_text.strip(): query_terms = extract_query_terms(query_text) apply_query_term_boost(chunks, query_terms) - apply_title_boost(chunks, query_terms) - - # Apply popularity boost (karma-based for forum posts) - if chunks: - apply_popularity_boost(chunks) + # Apply title + popularity boosts (single DB query) + apply_source_boosts(chunks, query_terms) + elif chunks: + # No query terms, just apply popularity boost + apply_source_boosts(chunks, set()) # Rerank using cross-encoder for better precision if settings.ENABLE_RERANKING and chunks and query_text.strip(): diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py index 0b4f457..6eda554 100644 --- a/src/memory/common/settings.py +++ b/src/memory/common/settings.py @@ -178,6 +178,8 @@ ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True) ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True) ENABLE_HYDE_EXPANSION = boolean_env("ENABLE_HYDE_EXPANSION", True) HYDE_TIMEOUT = float(os.getenv("HYDE_TIMEOUT", "3.0")) +ENABLE_RERANKING = boolean_env("ENABLE_RERANKING", True) +RERANK_MODEL = os.getenv("RERANK_MODEL", "rerank-2-lite") MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16)) MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000)) diff --git a/tests/memory/api/search/test_search.py b/tests/memory/api/search/test_search.py index 1fec142..09bfef6 100644 --- a/tests/memory/api/search/test_search.py +++ b/tests/memory/api/search/test_search.py @@ -6,17 +6,23 @@ title boosting, and source deduplication. import pytest from unittest.mock import MagicMock, patch +from datetime import datetime, timedelta, timezone + from memory.api.search.search import ( extract_query_terms, apply_query_term_boost, deduplicate_by_source, apply_title_boost, apply_popularity_boost, + apply_source_boosts, + expand_query, fuse_scores_rrf, STOPWORDS, QUERY_TERM_BOOST, TITLE_MATCH_BOOST, POPULARITY_BOOST, + RECENCY_BOOST_MAX, + RECENCY_HALF_LIFE_DAYS, RRF_K, ) @@ -239,6 +245,8 @@ def test_apply_title_boost(mock_make_session, title, query_terms, initial_score, mock_source = MagicMock() mock_source.id = 1 mock_source.title = title + mock_source.popularity = 1.0 # Default popularity, no boost + mock_source.inserted_at = None # No recency boost mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] chunks = [_make_title_chunk(1, initial_score)] @@ -268,6 +276,8 @@ def test_apply_title_boost_none_title(mock_make_session): mock_source = MagicMock() mock_source.id = 1 mock_source.title = None + mock_source.popularity = 1.0 # Default popularity, no boost + mock_source.inserted_at = None # No recency boost mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] chunks = [_make_title_chunk(1, 0.5)] @@ -307,6 +317,7 @@ def test_apply_popularity_boost(mock_make_session, popularity, initial_score, ex mock_source = MagicMock() mock_source.id = 1 mock_source.popularity = popularity + mock_source.inserted_at = None # No recency boost mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] chunks = [_make_pop_chunk(1, initial_score)] @@ -331,9 +342,11 @@ def test_apply_popularity_boost_multiple_sources(mock_make_session): source1 = MagicMock() source1.id = 1 source1.popularity = 2.0 # High karma + source1.inserted_at = None # No recency boost source2 = MagicMock() source2.id = 2 source2.popularity = 1.0 # Default + source2.inserted_at = None # No recency boost mock_session.query.return_value.filter.return_value.all.return_value = [source1, source2] chunks = [_make_pop_chunk(1, 0.5), _make_pop_chunk(2, 0.5)] @@ -423,3 +436,233 @@ def test_fuse_scores_rrf_only_ranks_matter(): assert result1["a"] == pytest.approx(result2["a"]) assert result1["b"] == pytest.approx(result2["b"]) assert result1["c"] == pytest.approx(result2["c"]) + + +# ============================================================================ +# apply_source_boosts recency tests +# ============================================================================ + + +def _make_recency_chunk(source_id: int, score: float = 0.5): + """Create a mock chunk for recency boost tests.""" + chunk = MagicMock() + chunk.source_id = source_id + chunk.relevance_score = score + return chunk + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_new_content(mock_make_session): + """Brand new content should get full recency boost.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + now = datetime.now(timezone.utc) + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_source.popularity = 1.0 + mock_source.inserted_at = now # Just inserted + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_recency_chunk(1, 0.5)] + apply_source_boosts(chunks, set()) + + # Should get nearly full recency boost + expected = 0.5 + RECENCY_BOOST_MAX + assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01) + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_half_life_decay(mock_make_session): + """Content at half-life age should get half the boost.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + now = datetime.now(timezone.utc) + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_source.popularity = 1.0 + mock_source.inserted_at = now - timedelta(days=RECENCY_HALF_LIFE_DAYS) + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_recency_chunk(1, 0.5)] + apply_source_boosts(chunks, set()) + + # Should get half the recency boost + expected = 0.5 + RECENCY_BOOST_MAX * 0.5 + assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01) + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_old_content(mock_make_session): + """Very old content should get minimal recency boost.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + now = datetime.now(timezone.utc) + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_source.popularity = 1.0 + mock_source.inserted_at = now - timedelta(days=365) # 1 year old + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_recency_chunk(1, 0.5)] + apply_source_boosts(chunks, set()) + + # Should get very little boost (about 0.5^4 ≈ 0.0625 of max) + assert chunks[0].relevance_score > 0.5 + assert chunks[0].relevance_score < 0.5 + RECENCY_BOOST_MAX * 0.1 + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_none_timestamp(mock_make_session): + """Should handle None inserted_at gracefully.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_source.popularity = 1.0 + mock_source.inserted_at = None + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_recency_chunk(1, 0.5)] + apply_source_boosts(chunks, set()) + + # No recency boost applied + assert chunks[0].relevance_score == 0.5 + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_timezone_naive(mock_make_session): + """Should handle timezone-naive timestamps.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + # Timezone-naive timestamp + naive_dt = datetime.now().replace(tzinfo=None) + mock_source = MagicMock() + mock_source.id = 1 + mock_source.title = None + mock_source.popularity = 1.0 + mock_source.inserted_at = naive_dt + mock_session.query.return_value.filter.return_value.all.return_value = [mock_source] + + chunks = [_make_recency_chunk(1, 0.5)] + apply_source_boosts(chunks, set()) # Should not raise + + # Should get nearly full boost since it's very recent + assert chunks[0].relevance_score > 0.5 + + +@patch("memory.api.search.search.make_session") +def test_recency_boost_ordering(mock_make_session): + """Newer content should rank higher than older content.""" + mock_session = MagicMock() + mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_make_session.return_value.__exit__ = MagicMock(return_value=None) + + now = datetime.now(timezone.utc) + source_new = MagicMock() + source_new.id = 1 + source_new.title = None + source_new.popularity = 1.0 + source_new.inserted_at = now - timedelta(days=1) + + source_old = MagicMock() + source_old.id = 2 + source_old.title = None + source_old.popularity = 1.0 + source_old.inserted_at = now - timedelta(days=180) + + mock_session.query.return_value.filter.return_value.all.return_value = [source_new, source_old] + + chunks = [_make_recency_chunk(1, 0.5), _make_recency_chunk(2, 0.5)] + apply_source_boosts(chunks, set()) + + # Newer content should have higher score + assert chunks[0].relevance_score > chunks[1].relevance_score + + +# ============================================================================ +# expand_query tests +# ============================================================================ + + +@pytest.mark.parametrize( + "query,expected_expansion", + [ + # AI/ML abbreviations + ("ML algorithms", "artificial intelligence"), # Not "ML" -> won't have AI + ("AI safety research", "artificial intelligence"), + ("NLP models for text", "natural language processing"), + ("deep learning vs DL", "deep learning"), + # Rationality/EA terms + ("EA organizations", "effective altruism"), + ("AGI timeline predictions", "artificial general intelligence"), + ("x-risk reduction", "existential risk"), + # Reverse mappings + ("machine learning basics", "ml"), + ("artificial intelligence ethics", "ai"), + ], +) +def test_expand_query_adds_synonyms(query, expected_expansion): + """Should expand abbreviations and add synonyms.""" + result = expand_query(query) + assert expected_expansion in result.lower() + assert query in result # Original query preserved + + +@pytest.mark.parametrize( + "query", + [ + "hello world", # No expansions + "python programming", # No expansions + "database optimization", # No expansions + ], +) +def test_expand_query_no_match(query): + """Should return original query when no expansions match.""" + result = expand_query(query) + assert result == query + + +def test_expand_query_case_insensitive(): + """Should match terms regardless of case.""" + assert "artificial intelligence" in expand_query("AI research").lower() + assert "artificial intelligence" in expand_query("ai research").lower() + assert "artificial intelligence" in expand_query("Ai Research").lower() + + +def test_expand_query_word_boundaries(): + """Should only match whole words, not partial matches.""" + # "mail" contains "ai" but shouldn't trigger expansion + result = expand_query("email server") + assert result == "email server" + + # "claim" contains "ai" but shouldn't trigger expansion + result = expand_query("insurance claim") + assert result == "insurance claim" + + +def test_expand_query_multiple_terms(): + """Should expand multiple matching terms.""" + result = expand_query("AI and ML applications") + assert "artificial intelligence" in result.lower() + assert "machine learning" in result.lower() + + +def test_expand_query_preserves_original(): + """Should preserve original query text.""" + original = "AI safety research" + result = expand_query(original) + assert result.startswith(original)