mirror of https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00

more search improvements

This commit is contained in:
parent f3d8b6602b
commit d9fcfe3878
@@ -179,7 +179,7 @@ services:
       VITE_SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
       STATIC_DIR: "/app/static"
       VOYAGE_API_KEY: ${VOYAGE_API_KEY}
-      ENABLE_BM25_SEARCH: false
+      ENABLE_BM25_SEARCH: true
       OPENAI_API_KEY_FILE: /run/secrets/openai_key
       ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
     secrets: [postgres_password, openai_key, anthropic_key]
@@ -4,5 +4,4 @@ python-jose==3.3.0
python-multipart==0.0.9
sqladmin==0.20.1
mcp==1.10.0
bm25s[full]==0.2.13
slowapi==0.1.9
@@ -4,7 +4,10 @@ Search endpoints for the knowledge base API.

 import asyncio
 import logging
+import math
+import re
 from collections import defaultdict
+from datetime import datetime, timezone
 from typing import Optional
 from sqlalchemy.orm import load_only
 from memory.common import extract, settings
@@ -51,6 +54,76 @@ TITLE_MATCH_BOOST = 0.01
 # This gives a small boost to popular items without dominating relevance
 POPULARITY_BOOST = 0.02

+# Recency boost settings
+# Maximum bonus for brand new content (additive)
+RECENCY_BOOST_MAX = 0.005
+# Half-life in days: content loses half its recency boost every N days
+RECENCY_HALF_LIFE_DAYS = 90
+
+# Query expansion: map abbreviations/acronyms to full forms
+# These help match when users search for "ML" but documents say "machine learning"
+QUERY_EXPANSIONS: dict[str, list[str]] = {
+    # AI/ML abbreviations
+    "ai": ["artificial intelligence"],
+    "ml": ["machine learning"],
+    "dl": ["deep learning"],
+    "nlp": ["natural language processing"],
+    "cv": ["computer vision"],
+    "rl": ["reinforcement learning"],
+    "llm": ["large language model", "language model"],
+    "gpt": ["generative pretrained transformer", "language model"],
+    "nn": ["neural network"],
+    "cnn": ["convolutional neural network"],
+    "rnn": ["recurrent neural network"],
+    "lstm": ["long short term memory"],
+    "gan": ["generative adversarial network"],
+    "rag": ["retrieval augmented generation"],
+    # Rationality/EA terms
+    "ea": ["effective altruism"],
+    "lw": ["lesswrong", "less wrong"],
+    "gwwc": ["giving what we can"],
+    "agi": ["artificial general intelligence"],
+    "asi": ["artificial superintelligence"],
+    "fai": ["friendly ai", "ai alignment"],
+    "x-risk": ["existential risk"],
+    "xrisk": ["existential risk"],
+    "p(doom)": ["probability of doom", "ai risk"],
+    # Reverse mappings (full forms -> abbreviations)
+    "artificial intelligence": ["ai"],
+    "machine learning": ["ml"],
+    "deep learning": ["dl"],
+    "natural language processing": ["nlp"],
+    "computer vision": ["cv"],
+    "reinforcement learning": ["rl"],
+    "neural network": ["nn"],
+    "effective altruism": ["ea"],
+    "existential risk": ["x-risk", "xrisk"],
+}
+
+
+def expand_query(query: str) -> str:
+    """
+    Expand query with synonyms and abbreviations.
+
+    This helps match documents that use different terminology for the same concept.
+    For example, "ML algorithms" -> "ML algorithms machine learning"
+    """
+    query_lower = query.lower()
+    expansions = []
+
+    for term, synonyms in QUERY_EXPANSIONS.items():
+        # Check if term appears as a whole word in the query
+        # Use word boundaries to avoid matching partial words
+        pattern = r'\b' + re.escape(term) + r'\b'
+        if re.search(pattern, query_lower):
+            expansions.extend(synonyms)
+
+    if expansions:
+        # Add expansions to the original query
+        return query + " " + " ".join(expansions)
+    return query
+
+
 # Common words to ignore when checking for query term presence
 STOPWORDS = {
     "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
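As an editorial aside, not part of the diff: the two recency constants above imply a simple exponential schedule, and it is worth seeing the actual numbers. A minimal sketch of the arithmetic, using only the constants defined in this hunk:

import math

RECENCY_BOOST_MAX = 0.005
RECENCY_HALF_LIFE_DAYS = 90

for age_days in (0, 90, 180, 365):
    decay = math.pow(0.5, age_days / RECENCY_HALF_LIFE_DAYS)
    print(f"{age_days:>3} days old -> additive boost {RECENCY_BOOST_MAX * decay:.5f}")
# 0 -> 0.00500, 90 -> 0.00250, 180 -> 0.00125, 365 -> 0.00030

So even brand-new content gets at most +0.005, and a year-old item keeps only about 6% of that.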
@@ -116,66 +189,78 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
     return list(best_by_source.values())


-def apply_title_boost(
+def apply_source_boosts(
     chunks: list[Chunk],
     query_terms: set[str],
 ) -> None:
     """
-    Boost chunks when query terms match the source title.
-
-    Title matches are a strong signal since titles summarize content.
-    """
-    if not query_terms or not chunks:
-        return
-
-    # Get unique source IDs
-    source_ids = list({chunk.source_id for chunk in chunks})
-
-    # Fetch full source items (polymorphic) to access title attribute
-    with make_session() as db:
-        sources = db.query(SourceItem).filter(
-            SourceItem.id.in_(source_ids)
-        ).all()
-        titles = {s.id: (getattr(s, 'title', None) or "").lower() for s in sources}
-
-    # Apply boost to chunks whose source title matches query terms
-    for chunk in chunks:
-        title = titles.get(chunk.source_id, "")
-        if not title:
-            continue
-
-        matches = sum(1 for term in query_terms if term in title)
-        if matches > 0:
-            boost = TITLE_MATCH_BOOST * (matches / len(query_terms))
-            chunk.relevance_score = (chunk.relevance_score or 0) + boost
-
-
-def apply_popularity_boost(chunks: list[Chunk]) -> None:
-    """
-    Boost chunks based on source popularity.
-
-    Uses the popularity property from SourceItem subclasses.
-    ForumPost uses karma, others default to 1.0.
+    Apply title, popularity, and recency boosts to chunks in a single DB query.
+
+    - Title boost: chunks get boosted when query terms appear in source title
+    - Popularity boost: chunks get boosted based on source karma/popularity
+    - Recency boost: newer content gets a small boost that decays over time
     """
     if not chunks:
         return

     source_ids = list({chunk.source_id for chunk in chunks})
+    now = datetime.now(timezone.utc)

     # Single query to fetch all source metadata
     with make_session() as db:
         sources = db.query(SourceItem).filter(
             SourceItem.id.in_(source_ids)
         ).all()
-        popularity_map = {s.id: s.popularity for s in sources}
+        source_map = {
+            s.id: {
+                "title": (getattr(s, "title", None) or "").lower(),
+                "popularity": s.popularity,
+                "inserted_at": s.inserted_at,
+            }
+            for s in sources
+        }

     for chunk in chunks:
-        popularity = popularity_map.get(chunk.source_id, 1.0)
+        source_data = source_map.get(chunk.source_id, {})
+        score = chunk.relevance_score or 0
+
+        # Apply title boost if query terms match
+        if query_terms:
+            title = source_data.get("title", "")
+            if title:
+                matches = sum(1 for term in query_terms if term in title)
+                if matches > 0:
+                    score += TITLE_MATCH_BOOST * (matches / len(query_terms))
+
+        # Apply popularity boost
+        popularity = source_data.get("popularity", 1.0)
         if popularity != 1.0:
             # Apply boost: score * (1 + POPULARITY_BOOST * (popularity - 1))
             # For popularity=2.0: multiplier = 1.02
             # For popularity=0.5: multiplier = 0.99
             multiplier = 1.0 + POPULARITY_BOOST * (popularity - 1.0)
-            chunk.relevance_score = (chunk.relevance_score or 0) * multiplier
+            score *= multiplier
+
+        # Apply recency boost (exponential decay with half-life)
+        inserted_at = source_data.get("inserted_at")
+        if inserted_at:
+            # Handle timezone-naive timestamps
+            if inserted_at.tzinfo is None:
+                inserted_at = inserted_at.replace(tzinfo=timezone.utc)
+            age_days = (now - inserted_at).total_seconds() / 86400
+            # Exponential decay: boost = max_boost * 0.5^(age/half_life)
+            decay = math.pow(0.5, age_days / RECENCY_HALF_LIFE_DAYS)
+            score += RECENCY_BOOST_MAX * decay
+
+        chunk.relevance_score = score
+
+
+# Keep legacy functions for backwards compatibility and testing
+def apply_title_boost(chunks: list[Chunk], query_terms: set[str]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, query_terms)
+
+
+def apply_popularity_boost(chunks: list[Chunk]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, set())


 def fuse_scores_rrf(
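To sanity-check the combined boost logic, here is an illustrative walk-through (not code from this commit) of a chunk with base score 0.80 whose source title matches one of two query terms, has popularity 2.0, and is 30 days old, following the same order of operations as apply_source_boosts:

import math

TITLE_MATCH_BOOST = 0.01
POPULARITY_BOOST = 0.02
RECENCY_BOOST_MAX = 0.005
RECENCY_HALF_LIFE_DAYS = 90

score = 0.80
score += TITLE_MATCH_BOOST * (1 / 2)           # title match on 1 of 2 terms: +0.005
score *= 1.0 + POPULARITY_BOOST * (2.0 - 1.0)  # popularity 2.0: x1.02
score += RECENCY_BOOST_MAX * math.pow(0.5, 30 / RECENCY_HALF_LIFE_DAYS)
print(round(score, 5))  # 0.82507 -- boosts nudge the ranking, they don't dominate it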
@@ -242,12 +327,21 @@ async def search_chunks(
     # This helps find results that rank well in one method but not the other
     internal_limit = limit * CANDIDATE_MULTIPLIER

-    # Extract query text for HyDE expansion
-    search_data = list(data)  # Copy to avoid modifying original
-    if settings.ENABLE_HYDE_EXPANSION:
-        query_text = " ".join(
-            c for chunk in data for c in chunk.data if isinstance(c, str)
-        )
+    # Extract query text and apply synonym/abbreviation expansion
+    query_text = " ".join(
+        c for chunk in data for c in chunk.data if isinstance(c, str)
+    )
+    expanded_query = expand_query(query_text)
+
+    # If query was expanded, use expanded version for search
+    if expanded_query != query_text:
+        logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
+        search_data = [extract.DataChunk(data=[expanded_query])]
+    else:
+        search_data = list(data)  # Copy to avoid modifying original
+
+    # Apply HyDE expansion if enabled
+    if settings.ENABLE_HYDE_EXPANSION:
         # Only expand queries with 4+ words (short queries are usually specific enough)
         if len(query_text.split()) >= 4:
             try:
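For reference, this is how expand_query behaves in the wire-up above (a hypothetical REPL session; the outputs follow from the QUERY_EXPANSIONS table and the word-boundary matching):

from memory.api.search.search import expand_query

print(expand_query("ML algorithms"))  # "ML algorithms machine learning"
print(expand_query("email server"))   # "email server" -- "ai" inside "email" is not a whole word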
@@ -316,15 +410,15 @@ async def search_chunks(
         c for chunk in data for c in chunk.data if isinstance(c, str)
     )

-    # Apply query term presence boost and title boost
+    # Apply query term presence boost
     if chunks and query_text.strip():
         query_terms = extract_query_terms(query_text)
         apply_query_term_boost(chunks, query_terms)
-        apply_title_boost(chunks, query_terms)
-
-    # Apply popularity boost (karma-based for forum posts)
-    if chunks:
-        apply_popularity_boost(chunks)
+        # Apply title + popularity boosts (single DB query)
+        apply_source_boosts(chunks, query_terms)
+    elif chunks:
+        # No query terms, just apply popularity boost
+        apply_source_boosts(chunks, set())

     # Rerank using cross-encoder for better precision
     if settings.ENABLE_RERANKING and chunks and query_text.strip():
@@ -178,6 +178,8 @@ ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
ENABLE_HYDE_EXPANSION = boolean_env("ENABLE_HYDE_EXPANSION", True)
HYDE_TIMEOUT = float(os.getenv("HYDE_TIMEOUT", "3.0"))
ENABLE_RERANKING = boolean_env("ENABLE_RERANKING", True)
RERANK_MODEL = os.getenv("RERANK_MODEL", "rerank-2-lite")
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
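boolean_env is the repo's own helper for these flags; its definition is not part of this diff, so the following is only an assumed sketch of the semantics the settings above appear to rely on:

import os

def boolean_env(name: str, default: bool = False) -> bool:
    # Assumed behavior (not confirmed by this diff): unset -> default,
    # otherwise accept the usual truthy strings.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")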
@@ -6,17 +6,23 @@ title boosting, and source deduplication.
 import pytest
 from unittest.mock import MagicMock, patch

+from datetime import datetime, timedelta, timezone
+
 from memory.api.search.search import (
     extract_query_terms,
     apply_query_term_boost,
     deduplicate_by_source,
     apply_title_boost,
     apply_popularity_boost,
+    apply_source_boosts,
+    expand_query,
     fuse_scores_rrf,
     STOPWORDS,
     QUERY_TERM_BOOST,
     TITLE_MATCH_BOOST,
     POPULARITY_BOOST,
+    RECENCY_BOOST_MAX,
+    RECENCY_HALF_LIFE_DAYS,
     RRF_K,
 )
@@ -239,6 +245,8 @@ def test_apply_title_boost(mock_make_session, title, query_terms, initial_score,
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = title
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_title_chunk(1, initial_score)]
@@ -268,6 +276,8 @@ def test_apply_title_boost_none_title(mock_make_session):
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = None
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_title_chunk(1, 0.5)]
@@ -307,6 +317,7 @@ def test_apply_popularity_boost(mock_make_session, popularity, initial_score, ex
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.popularity = popularity
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]

     chunks = [_make_pop_chunk(1, initial_score)]
@@ -331,9 +342,11 @@ def test_apply_popularity_boost_multiple_sources(mock_make_session):
     source1 = MagicMock()
     source1.id = 1
     source1.popularity = 2.0  # High karma
+    source1.inserted_at = None  # No recency boost
     source2 = MagicMock()
     source2.id = 2
     source2.popularity = 1.0  # Default
+    source2.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [source1, source2]

     chunks = [_make_pop_chunk(1, 0.5), _make_pop_chunk(2, 0.5)]
@@ -423,3 +436,233 @@ def test_fuse_scores_rrf_only_ranks_matter():
     assert result1["a"] == pytest.approx(result2["a"])
     assert result1["b"] == pytest.approx(result2["b"])
     assert result1["c"] == pytest.approx(result2["c"])
+
+
+# ============================================================================
+# apply_source_boosts recency tests
+# ============================================================================
+
+
+def _make_recency_chunk(source_id: int, score: float = 0.5):
+    """Create a mock chunk for recency boost tests."""
+    chunk = MagicMock()
+    chunk.source_id = source_id
+    chunk.relevance_score = score
+    return chunk
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_new_content(mock_make_session):
+    """Brand new content should get full recency boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now  # Just inserted
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get nearly full recency boost
+    expected = 0.5 + RECENCY_BOOST_MAX
+    assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_half_life_decay(mock_make_session):
+    """Content at half-life age should get half the boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now - timedelta(days=RECENCY_HALF_LIFE_DAYS)
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get half the recency boost
+    expected = 0.5 + RECENCY_BOOST_MAX * 0.5
+    assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_old_content(mock_make_session):
+    """Very old content should get minimal recency boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now - timedelta(days=365)  # 1 year old
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get very little boost (about 0.5^4 ≈ 0.0625 of max)
+    assert chunks[0].relevance_score > 0.5
+    assert chunks[0].relevance_score < 0.5 + RECENCY_BOOST_MAX * 0.1
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_none_timestamp(mock_make_session):
+    """Should handle None inserted_at gracefully."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = None
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # No recency boost applied
+    assert chunks[0].relevance_score == 0.5
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_timezone_naive(mock_make_session):
+    """Should handle timezone-naive timestamps."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    # Timezone-naive timestamp
+    naive_dt = datetime.now().replace(tzinfo=None)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = naive_dt
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())  # Should not raise
+
+    # Should get nearly full boost since it's very recent
+    assert chunks[0].relevance_score > 0.5
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_ordering(mock_make_session):
+    """Newer content should rank higher than older content."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    source_new = MagicMock()
+    source_new.id = 1
+    source_new.title = None
+    source_new.popularity = 1.0
+    source_new.inserted_at = now - timedelta(days=1)
+
+    source_old = MagicMock()
+    source_old.id = 2
+    source_old.title = None
+    source_old.popularity = 1.0
+    source_old.inserted_at = now - timedelta(days=180)
+
+    mock_session.query.return_value.filter.return_value.all.return_value = [source_new, source_old]
+
+    chunks = [_make_recency_chunk(1, 0.5), _make_recency_chunk(2, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Newer content should have higher score
+    assert chunks[0].relevance_score > chunks[1].relevance_score
+
+
+# ============================================================================
+# expand_query tests
+# ============================================================================
+
+
+@pytest.mark.parametrize(
+    "query,expected_expansion",
+    [
+        # AI/ML abbreviations
+        ("ML algorithms", "machine learning"),
+        ("AI safety research", "artificial intelligence"),
+        ("NLP models for text", "natural language processing"),
+        ("deep learning vs DL", "deep learning"),
+        # Rationality/EA terms
+        ("EA organizations", "effective altruism"),
+        ("AGI timeline predictions", "artificial general intelligence"),
+        ("x-risk reduction", "existential risk"),
+        # Reverse mappings
+        ("machine learning basics", "ml"),
+        ("artificial intelligence ethics", "ai"),
+    ],
+)
+def test_expand_query_adds_synonyms(query, expected_expansion):
+    """Should expand abbreviations and add synonyms."""
+    result = expand_query(query)
+    assert expected_expansion in result.lower()
+    assert query in result  # Original query preserved
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "hello world",  # No expansions
+        "python programming",  # No expansions
+        "database optimization",  # No expansions
+    ],
+)
+def test_expand_query_no_match(query):
+    """Should return original query when no expansions match."""
+    result = expand_query(query)
+    assert result == query
+
+
+def test_expand_query_case_insensitive():
+    """Should match terms regardless of case."""
+    assert "artificial intelligence" in expand_query("AI research").lower()
+    assert "artificial intelligence" in expand_query("ai research").lower()
+    assert "artificial intelligence" in expand_query("Ai Research").lower()
+
+
+def test_expand_query_word_boundaries():
+    """Should only match whole words, not partial matches."""
+    # "mail" contains "ai" but shouldn't trigger expansion
+    result = expand_query("email server")
+    assert result == "email server"
+
+    # "claim" contains "ai" but shouldn't trigger expansion
+    result = expand_query("insurance claim")
+    assert result == "insurance claim"
+
+
+def test_expand_query_multiple_terms():
+    """Should expand multiple matching terms."""
+    result = expand_query("AI and ML applications")
+    assert "artificial intelligence" in result.lower()
+    assert "machine learning" in result.lower()
+
+
+def test_expand_query_preserves_original():
+    """Should preserve original query text."""
+    original = "AI safety research"
+    result = expand_query(original)
+    assert result.startswith(original)