more search improvements

mruwnik 2025-12-21 12:29:27 +00:00
parent f3d8b6602b
commit d9fcfe3878
5 changed files with 392 additions and 54 deletions

View File

@@ -179,7 +179,7 @@ services:
       VITE_SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
       STATIC_DIR: "/app/static"
       VOYAGE_API_KEY: ${VOYAGE_API_KEY}
-      ENABLE_BM25_SEARCH: false
+      ENABLE_BM25_SEARCH: true
       OPENAI_API_KEY_FILE: /run/secrets/openai_key
       ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
     secrets: [postgres_password, openai_key, anthropic_key]

View File

@@ -4,5 +4,4 @@ python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin==0.20.1
 mcp==1.10.0
-bm25s[full]==0.2.13
 slowapi==0.1.9

View File

@@ -4,7 +4,10 @@ Search endpoints for the knowledge base API.
 import asyncio
 import logging
+import math
+import re
 from collections import defaultdict
+from datetime import datetime, timezone
 from typing import Optional

 from sqlalchemy.orm import load_only
 from memory.common import extract, settings
@@ -51,6 +54,76 @@ TITLE_MATCH_BOOST = 0.01
# This gives a small boost to popular items without dominating relevance
POPULARITY_BOOST = 0.02

# Recency boost settings
# Maximum bonus for brand new content (additive)
RECENCY_BOOST_MAX = 0.005
# Half-life in days: content loses half its recency boost every N days
RECENCY_HALF_LIFE_DAYS = 90
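# For intuition, the decay arithmetic works out as follows (an
# illustrative aside, assuming the two constants above):
#   age 0 days   -> boost = 0.005 * 0.5**(0/90)   = 0.005
#   age 90 days  -> boost = 0.005 * 0.5**(90/90)  = 0.0025
#   age 180 days -> boost = 0.005 * 0.5**(180/90) = 0.00125
#   age 365 days -> boost = 0.005 * 0.5**(365/90) ~= 0.0003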
# Query expansion: map abbreviations/acronyms to full forms
# These help match when users search for "ML" but documents say "machine learning"
QUERY_EXPANSIONS: dict[str, list[str]] = {
# AI/ML abbreviations
"ai": ["artificial intelligence"],
"ml": ["machine learning"],
"dl": ["deep learning"],
"nlp": ["natural language processing"],
"cv": ["computer vision"],
"rl": ["reinforcement learning"],
"llm": ["large language model", "language model"],
"gpt": ["generative pretrained transformer", "language model"],
"nn": ["neural network"],
"cnn": ["convolutional neural network"],
"rnn": ["recurrent neural network"],
"lstm": ["long short term memory"],
"gan": ["generative adversarial network"],
"rag": ["retrieval augmented generation"],
# Rationality/EA terms
"ea": ["effective altruism"],
"lw": ["lesswrong", "less wrong"],
"gwwc": ["giving what we can"],
"agi": ["artificial general intelligence"],
"asi": ["artificial superintelligence"],
"fai": ["friendly ai", "ai alignment"],
"x-risk": ["existential risk"],
"xrisk": ["existential risk"],
"p(doom)": ["probability of doom", "ai risk"],
# Reverse mappings (full forms -> abbreviations)
"artificial intelligence": ["ai"],
"machine learning": ["ml"],
"deep learning": ["dl"],
"natural language processing": ["nlp"],
"computer vision": ["cv"],
"reinforcement learning": ["rl"],
"neural network": ["nn"],
"effective altruism": ["ea"],
"existential risk": ["x-risk", "xrisk"],
}
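# Because the table encodes both directions by hand, a small consistency
# check (illustrative only, not part of this commit) can catch drift
# between forward and reverse entries:
#
#   for full_form in ("artificial intelligence", "machine learning"):
#       for abbrev in QUERY_EXPANSIONS[full_form]:
#           assert full_form in QUERY_EXPANSIONS.get(abbrev, [])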
def expand_query(query: str) -> str:
    """
    Expand query with synonyms and abbreviations.
    This helps match documents that use different terminology for the same concept.
    For example, "ML algorithms" -> "ML algorithms machine learning"
    """
    query_lower = query.lower()
    expansions = []
    for term, synonyms in QUERY_EXPANSIONS.items():
        # Check if the term appears as a whole word in the query.
        # Lookarounds are used instead of \b because terms with non-word
        # edges (e.g. "p(doom)") never match when \b sits next to ")".
        pattern = r"(?<!\w)" + re.escape(term) + r"(?!\w)"
        if re.search(pattern, query_lower):
            expansions.extend(synonyms)
    if expansions:
        # Append expansions so the original query text is preserved
        return query + " " + " ".join(expansions)
    return query
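# Example behaviour (follows from the mapping table above):
#   expand_query("ML algorithms") -> "ML algorithms machine learning"
#   expand_query("AI and x-risk") -> "AI and x-risk artificial intelligence existential risk"
#   expand_query("email server")  -> "email server"  (no whole-word match on "ai")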
# Common words to ignore when checking for query term presence
STOPWORDS = {
    "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
@@ -116,66 +189,78 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
     return list(best_by_source.values())

-def apply_title_boost(
+def apply_source_boosts(
     chunks: list[Chunk],
     query_terms: set[str],
 ) -> None:
     """
-    Boost chunks when query terms match the source title.
-    Title matches are a strong signal since titles summarize content.
-    """
-    if not query_terms or not chunks:
-        return
-    # Get unique source IDs
-    source_ids = list({chunk.source_id for chunk in chunks})
-    # Fetch full source items (polymorphic) to access title attribute
-    with make_session() as db:
-        sources = db.query(SourceItem).filter(
-            SourceItem.id.in_(source_ids)
-        ).all()
-        titles = {s.id: (getattr(s, 'title', None) or "").lower() for s in sources}
-    # Apply boost to chunks whose source title matches query terms
-    for chunk in chunks:
-        title = titles.get(chunk.source_id, "")
-        if not title:
-            continue
-        matches = sum(1 for term in query_terms if term in title)
-        if matches > 0:
-            boost = TITLE_MATCH_BOOST * (matches / len(query_terms))
-            chunk.relevance_score = (chunk.relevance_score or 0) + boost
-
-def apply_popularity_boost(chunks: list[Chunk]) -> None:
-    """
-    Boost chunks based on source popularity.
-    Uses the popularity property from SourceItem subclasses.
-    ForumPost uses karma, others default to 1.0.
+    Apply title, popularity, and recency boosts to chunks in a single DB query.
+    - Title boost: chunks get boosted when query terms appear in source title
+    - Popularity boost: chunks get boosted based on source karma/popularity
+    - Recency boost: newer content gets a small boost that decays over time
     """
     if not chunks:
         return
     source_ids = list({chunk.source_id for chunk in chunks})
+    now = datetime.now(timezone.utc)
+    # Single query to fetch all source metadata
     with make_session() as db:
         sources = db.query(SourceItem).filter(
             SourceItem.id.in_(source_ids)
         ).all()
-        popularity_map = {s.id: s.popularity for s in sources}
+        source_map = {
+            s.id: {
+                "title": (getattr(s, "title", None) or "").lower(),
+                "popularity": s.popularity,
+                "inserted_at": s.inserted_at,
+            }
+            for s in sources
+        }
     for chunk in chunks:
-        popularity = popularity_map.get(chunk.source_id, 1.0)
+        source_data = source_map.get(chunk.source_id, {})
+        score = chunk.relevance_score or 0
+        # Apply title boost if query terms match
+        if query_terms:
+            title = source_data.get("title", "")
+            if title:
+                matches = sum(1 for term in query_terms if term in title)
+                if matches > 0:
+                    score += TITLE_MATCH_BOOST * (matches / len(query_terms))
+        # Apply popularity boost
+        popularity = source_data.get("popularity", 1.0)
         if popularity != 1.0:
-            # Apply boost: score * (1 + POPULARITY_BOOST * (popularity - 1))
-            # For popularity=2.0: multiplier = 1.02
-            # For popularity=0.5: multiplier = 0.99
             multiplier = 1.0 + POPULARITY_BOOST * (popularity - 1.0)
-            chunk.relevance_score = (chunk.relevance_score or 0) * multiplier
+            score *= multiplier
+        # Apply recency boost (exponential decay with half-life)
+        inserted_at = source_data.get("inserted_at")
+        if inserted_at:
+            # Handle timezone-naive timestamps
+            if inserted_at.tzinfo is None:
+                inserted_at = inserted_at.replace(tzinfo=timezone.utc)
+            age_days = (now - inserted_at).total_seconds() / 86400
+            # Exponential decay: boost = max_boost * 0.5^(age/half_life)
+            decay = math.pow(0.5, age_days / RECENCY_HALF_LIFE_DAYS)
+            score += RECENCY_BOOST_MAX * decay
+        chunk.relevance_score = score
+
+# Keep legacy functions for backwards compatibility and testing
+def apply_title_boost(chunks: list[Chunk], query_terms: set[str]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, query_terms)
+
+def apply_popularity_boost(chunks: list[Chunk]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, set())
def fuse_scores_rrf(
@@ -242,12 +327,21 @@ async def search_chunks(
     # This helps find results that rank well in one method but not the other
     internal_limit = limit * CANDIDATE_MULTIPLIER

-    # Extract query text for HyDE expansion
-    search_data = list(data)  # Copy to avoid modifying original
-    if settings.ENABLE_HYDE_EXPANSION:
-        query_text = " ".join(
-            c for chunk in data for c in chunk.data if isinstance(c, str)
-        )
+    # Extract query text and apply synonym/abbreviation expansion
+    query_text = " ".join(
+        c for chunk in data for c in chunk.data if isinstance(c, str)
+    )
+    expanded_query = expand_query(query_text)
+    # If query was expanded, use expanded version for search
+    if expanded_query != query_text:
+        logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
+        search_data = [extract.DataChunk(data=[expanded_query])]
+    else:
+        search_data = list(data)  # Copy to avoid modifying original
+    # Apply HyDE expansion if enabled
+    if settings.ENABLE_HYDE_EXPANSION:
         # Only expand queries with 4+ words (short queries are usually specific enough)
         if len(query_text.split()) >= 4:
             try:
@@ -316,15 +410,15 @@ async def search_chunks(
         c for chunk in data for c in chunk.data if isinstance(c, str)
     )

-    # Apply query term presence boost and title boost
+    # Apply query term presence boost
     if chunks and query_text.strip():
         query_terms = extract_query_terms(query_text)
         apply_query_term_boost(chunks, query_terms)
-        apply_title_boost(chunks, query_terms)
-
-    # Apply popularity boost (karma-based for forum posts)
-    if chunks:
-        apply_popularity_boost(chunks)
+        # Apply title + popularity boosts (single DB query)
+        apply_source_boosts(chunks, query_terms)
+    elif chunks:
+        # No query terms, just apply popularity boost
+        apply_source_boosts(chunks, set())

     # Rerank using cross-encoder for better precision
     if settings.ENABLE_RERANKING and chunks and query_text.strip():
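Taken together, the boost arithmetic in apply_source_boosts works out as below; a standalone sketch mirroring the committed logic (the chunk values are illustrative):

TITLE_MATCH_BOOST = 0.01
POPULARITY_BOOST = 0.02
RECENCY_BOOST_MAX = 0.005
RECENCY_HALF_LIFE_DAYS = 90

score = 0.5              # base relevance from vector/BM25 fusion
matches, n_terms = 1, 2  # one of two query terms appears in the title
popularity = 2.0         # e.g. a high-karma forum post
age_days = 90            # exactly one half-life old

score += TITLE_MATCH_BOOST * (matches / n_terms)      # 0.505
score *= 1.0 + POPULARITY_BOOST * (popularity - 1.0)  # 0.5151
score += RECENCY_BOOST_MAX * 0.5 ** (age_days / RECENCY_HALF_LIFE_DAYS)
print(round(score, 4))   # 0.5176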

View File

@@ -178,6 +178,8 @@ ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
 ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 ENABLE_HYDE_EXPANSION = boolean_env("ENABLE_HYDE_EXPANSION", True)
 HYDE_TIMEOUT = float(os.getenv("HYDE_TIMEOUT", "3.0"))
+ENABLE_RERANKING = boolean_env("ENABLE_RERANKING", True)
+RERANK_MODEL = os.getenv("RERANK_MODEL", "rerank-2-lite")
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
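boolean_env is referenced here but defined elsewhere in settings; a plausible sketch of such a helper, assuming it accepts the usual truthy strings (hypothetical, for illustration only):

import os

def boolean_env(name: str, default: bool = False) -> bool:
    """Read an env var as a boolean, treating "1", "true", "yes", "on" as True."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}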

View File

@@ -6,17 +6,23 @@ title boosting, and source deduplication.
 import pytest
 from unittest.mock import MagicMock, patch
+from datetime import datetime, timedelta, timezone

 from memory.api.search.search import (
     extract_query_terms,
     apply_query_term_boost,
     deduplicate_by_source,
     apply_title_boost,
     apply_popularity_boost,
+    apply_source_boosts,
+    expand_query,
     fuse_scores_rrf,
     STOPWORDS,
     QUERY_TERM_BOOST,
     TITLE_MATCH_BOOST,
     POPULARITY_BOOST,
+    RECENCY_BOOST_MAX,
+    RECENCY_HALF_LIFE_DAYS,
     RRF_K,
 )
@@ -239,6 +245,8 @@ def test_apply_title_boost(mock_make_session, title, query_terms, initial_score,
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = title
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_title_chunk(1, initial_score)]
@@ -268,6 +276,8 @@ def test_apply_title_boost_none_title(mock_make_session):
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = None
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_title_chunk(1, 0.5)]
@@ -307,6 +317,7 @@ def test_apply_popularity_boost(mock_make_session, popularity, initial_score, ex
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.popularity = popularity
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_pop_chunk(1, initial_score)]
@@ -331,9 +342,11 @@ def test_apply_popularity_boost_multiple_sources(mock_make_session):
     source1 = MagicMock()
     source1.id = 1
     source1.popularity = 2.0  # High karma
+    source1.inserted_at = None  # No recency boost
     source2 = MagicMock()
     source2.id = 2
     source2.popularity = 1.0  # Default
+    source2.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [source1, source2]

     chunks = [_make_pop_chunk(1, 0.5), _make_pop_chunk(2, 0.5)]
@@ -423,3 +436,233 @@ def test_fuse_scores_rrf_only_ranks_matter():
    assert result1["a"] == pytest.approx(result2["a"])
    assert result1["b"] == pytest.approx(result2["b"])
    assert result1["c"] == pytest.approx(result2["c"])
# ============================================================================
# apply_source_boosts recency tests
# ============================================================================
def _make_recency_chunk(source_id: int, score: float = 0.5):
"""Create a mock chunk for recency boost tests."""
chunk = MagicMock()
chunk.source_id = source_id
chunk.relevance_score = score
return chunk
@patch("memory.api.search.search.make_session")
def test_recency_boost_new_content(mock_make_session):
"""Brand new content should get full recency boost."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
now = datetime.now(timezone.utc)
mock_source = MagicMock()
mock_source.id = 1
mock_source.title = None
mock_source.popularity = 1.0
mock_source.inserted_at = now # Just inserted
mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
chunks = [_make_recency_chunk(1, 0.5)]
apply_source_boosts(chunks, set())
# Should get nearly full recency boost
expected = 0.5 + RECENCY_BOOST_MAX
assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
@patch("memory.api.search.search.make_session")
def test_recency_boost_half_life_decay(mock_make_session):
"""Content at half-life age should get half the boost."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
now = datetime.now(timezone.utc)
mock_source = MagicMock()
mock_source.id = 1
mock_source.title = None
mock_source.popularity = 1.0
mock_source.inserted_at = now - timedelta(days=RECENCY_HALF_LIFE_DAYS)
mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
chunks = [_make_recency_chunk(1, 0.5)]
apply_source_boosts(chunks, set())
# Should get half the recency boost
expected = 0.5 + RECENCY_BOOST_MAX * 0.5
assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
@patch("memory.api.search.search.make_session")
def test_recency_boost_old_content(mock_make_session):
"""Very old content should get minimal recency boost."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
now = datetime.now(timezone.utc)
mock_source = MagicMock()
mock_source.id = 1
mock_source.title = None
mock_source.popularity = 1.0
mock_source.inserted_at = now - timedelta(days=365) # 1 year old
mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
chunks = [_make_recency_chunk(1, 0.5)]
apply_source_boosts(chunks, set())
# Should get very little boost (about 0.5^4 ≈ 0.0625 of max)
assert chunks[0].relevance_score > 0.5
assert chunks[0].relevance_score < 0.5 + RECENCY_BOOST_MAX * 0.1
@patch("memory.api.search.search.make_session")
def test_recency_boost_none_timestamp(mock_make_session):
"""Should handle None inserted_at gracefully."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
mock_source = MagicMock()
mock_source.id = 1
mock_source.title = None
mock_source.popularity = 1.0
mock_source.inserted_at = None
mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
chunks = [_make_recency_chunk(1, 0.5)]
apply_source_boosts(chunks, set())
# No recency boost applied
assert chunks[0].relevance_score == 0.5
@patch("memory.api.search.search.make_session")
def test_recency_boost_timezone_naive(mock_make_session):
"""Should handle timezone-naive timestamps."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
# Timezone-naive timestamp
naive_dt = datetime.now().replace(tzinfo=None)
mock_source = MagicMock()
mock_source.id = 1
mock_source.title = None
mock_source.popularity = 1.0
mock_source.inserted_at = naive_dt
mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
chunks = [_make_recency_chunk(1, 0.5)]
apply_source_boosts(chunks, set()) # Should not raise
# Should get nearly full boost since it's very recent
assert chunks[0].relevance_score > 0.5
@patch("memory.api.search.search.make_session")
def test_recency_boost_ordering(mock_make_session):
"""Newer content should rank higher than older content."""
mock_session = MagicMock()
mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
now = datetime.now(timezone.utc)
source_new = MagicMock()
source_new.id = 1
source_new.title = None
source_new.popularity = 1.0
source_new.inserted_at = now - timedelta(days=1)
source_old = MagicMock()
source_old.id = 2
source_old.title = None
source_old.popularity = 1.0
source_old.inserted_at = now - timedelta(days=180)
mock_session.query.return_value.filter.return_value.all.return_value = [source_new, source_old]
chunks = [_make_recency_chunk(1, 0.5), _make_recency_chunk(2, 0.5)]
apply_source_boosts(chunks, set())
# Newer content should have higher score
assert chunks[0].relevance_score > chunks[1].relevance_score
# ============================================================================
# expand_query tests
# ============================================================================
@pytest.mark.parametrize(
"query,expected_expansion",
[
# AI/ML abbreviations
("ML algorithms", "artificial intelligence"), # Not "ML" -> won't have AI
("AI safety research", "artificial intelligence"),
("NLP models for text", "natural language processing"),
("deep learning vs DL", "deep learning"),
# Rationality/EA terms
("EA organizations", "effective altruism"),
("AGI timeline predictions", "artificial general intelligence"),
("x-risk reduction", "existential risk"),
# Reverse mappings
("machine learning basics", "ml"),
("artificial intelligence ethics", "ai"),
],
)
def test_expand_query_adds_synonyms(query, expected_expansion):
"""Should expand abbreviations and add synonyms."""
result = expand_query(query)
assert expected_expansion in result.lower()
assert query in result # Original query preserved
@pytest.mark.parametrize(
"query",
[
"hello world", # No expansions
"python programming", # No expansions
"database optimization", # No expansions
],
)
def test_expand_query_no_match(query):
"""Should return original query when no expansions match."""
result = expand_query(query)
assert result == query
def test_expand_query_case_insensitive():
"""Should match terms regardless of case."""
assert "artificial intelligence" in expand_query("AI research").lower()
assert "artificial intelligence" in expand_query("ai research").lower()
assert "artificial intelligence" in expand_query("Ai Research").lower()
def test_expand_query_word_boundaries():
"""Should only match whole words, not partial matches."""
# "mail" contains "ai" but shouldn't trigger expansion
result = expand_query("email server")
assert result == "email server"
# "claim" contains "ai" but shouldn't trigger expansion
result = expand_query("insurance claim")
assert result == "insurance claim"
def test_expand_query_multiple_terms():
"""Should expand multiple matching terms."""
result = expand_query("AI and ML applications")
assert "artificial intelligence" in result.lower()
assert "machine learning" in result.lower()
def test_expand_query_preserves_original():
"""Should preserve original query text."""
original = "AI safety research"
result = expand_query(original)
assert result.startswith(original)
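For context on the fuse_scores_rrf tests above: reciprocal rank fusion scores each document by summed reciprocal ranks, which is why only ranks, never raw scores, affect the fused result. The implementation is not part of this diff, so the following is a minimal sketch of the standard formula (the real function's name, signature, and RRF_K value may differ):

from collections import defaultdict

RRF_K = 60  # assumed smoothing constant; the real value lives in search.py

def rrf_fuse(rankings: list[list[str]]) -> dict[str, float]:
    """Standard RRF: score(d) = sum over result lists of 1 / (RRF_K + rank)."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (RRF_K + rank)
    return dict(scores)

# Only the ordering of each result list matters:
print(rrf_fuse([["a", "b", "c"], ["b", "a", "c"]]))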