Mirror of https://github.com/mruwnik/memory.git
Synced 2026-01-02 17:22:58 +01:00

Commit d9fcfe3878 ("more search improvements"), parent f3d8b6602b
@@ -179,7 +179,7 @@ services:
       VITE_SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
       STATIC_DIR: "/app/static"
       VOYAGE_API_KEY: ${VOYAGE_API_KEY}
-      ENABLE_BM25_SEARCH: false
+      ENABLE_BM25_SEARCH: true
       OPENAI_API_KEY_FILE: /run/secrets/openai_key
       ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
     secrets: [postgres_password, openai_key, anthropic_key]
@@ -4,5 +4,4 @@ python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin==0.20.1
 mcp==1.10.0
-bm25s[full]==0.2.13
 slowapi==0.1.9
@@ -4,7 +4,10 @@ Search endpoints for the knowledge base API.
 
 import asyncio
 import logging
+import math
+import re
 from collections import defaultdict
+from datetime import datetime, timezone
 from typing import Optional
 from sqlalchemy.orm import load_only
 from memory.common import extract, settings
@@ -51,6 +54,76 @@ TITLE_MATCH_BOOST = 0.01
 # This gives a small boost to popular items without dominating relevance
 POPULARITY_BOOST = 0.02
 
+# Recency boost settings
+# Maximum bonus for brand new content (additive)
+RECENCY_BOOST_MAX = 0.005
+# Half-life in days: content loses half its recency boost every N days
+RECENCY_HALF_LIFE_DAYS = 90
+
+# Query expansion: map abbreviations/acronyms to full forms
+# These help match when users search for "ML" but documents say "machine learning"
+QUERY_EXPANSIONS: dict[str, list[str]] = {
+    # AI/ML abbreviations
+    "ai": ["artificial intelligence"],
+    "ml": ["machine learning"],
+    "dl": ["deep learning"],
+    "nlp": ["natural language processing"],
+    "cv": ["computer vision"],
+    "rl": ["reinforcement learning"],
+    "llm": ["large language model", "language model"],
+    "gpt": ["generative pretrained transformer", "language model"],
+    "nn": ["neural network"],
+    "cnn": ["convolutional neural network"],
+    "rnn": ["recurrent neural network"],
+    "lstm": ["long short term memory"],
+    "gan": ["generative adversarial network"],
+    "rag": ["retrieval augmented generation"],
+    # Rationality/EA terms
+    "ea": ["effective altruism"],
+    "lw": ["lesswrong", "less wrong"],
+    "gwwc": ["giving what we can"],
+    "agi": ["artificial general intelligence"],
+    "asi": ["artificial superintelligence"],
+    "fai": ["friendly ai", "ai alignment"],
+    "x-risk": ["existential risk"],
+    "xrisk": ["existential risk"],
+    "p(doom)": ["probability of doom", "ai risk"],
+    # Reverse mappings (full forms -> abbreviations)
+    "artificial intelligence": ["ai"],
+    "machine learning": ["ml"],
+    "deep learning": ["dl"],
+    "natural language processing": ["nlp"],
+    "computer vision": ["cv"],
+    "reinforcement learning": ["rl"],
+    "neural network": ["nn"],
+    "effective altruism": ["ea"],
+    "existential risk": ["x-risk", "xrisk"],
+}
+
+
+def expand_query(query: str) -> str:
+    """
+    Expand query with synonyms and abbreviations.
+
+    This helps match documents that use different terminology for the same concept.
+    For example, "ML algorithms" -> "ML machine learning algorithms"
+    """
+    query_lower = query.lower()
+    expansions = []
+
+    for term, synonyms in QUERY_EXPANSIONS.items():
+        # Check if term appears as a whole word in the query
+        # Use word boundaries to avoid matching partial words
+        pattern = r'\b' + re.escape(term) + r'\b'
+        if re.search(pattern, query_lower):
+            expansions.extend(synonyms)
+
+    if expansions:
+        # Add expansions to the original query
+        return query + " " + " ".join(expansions)
+    return query
+
+
 # Common words to ignore when checking for query term presence
 STOPWORDS = {
     "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
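A quick sketch of what the new expand_query does, using entries from the QUERY_EXPANSIONS table above (illustrative inputs; output ordering follows the dict's insertion order):

    >>> expand_query("ML algorithms")
    'ML algorithms machine learning'
    >>> expand_query("AI and ML applications")
    'AI and ML applications artificial intelligence machine learning'
    >>> expand_query("email server")  # "ai" inside "email" is not a whole word
    'email server'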
@@ -116,66 +189,78 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
     return list(best_by_source.values())
 
 
-def apply_title_boost(
+def apply_source_boosts(
     chunks: list[Chunk],
     query_terms: set[str],
 ) -> None:
     """
-    Boost chunks when query terms match the source title.
+    Apply title, popularity, and recency boosts to chunks in a single DB query.
 
-    Title matches are a strong signal since titles summarize content.
-    """
-    if not query_terms or not chunks:
-        return
-
-    # Get unique source IDs
-    source_ids = list({chunk.source_id for chunk in chunks})
-
-    # Fetch full source items (polymorphic) to access title attribute
-    with make_session() as db:
-        sources = db.query(SourceItem).filter(
-            SourceItem.id.in_(source_ids)
-        ).all()
-        titles = {s.id: (getattr(s, 'title', None) or "").lower() for s in sources}
-
-    # Apply boost to chunks whose source title matches query terms
-    for chunk in chunks:
-        title = titles.get(chunk.source_id, "")
-        if not title:
-            continue
-
-        matches = sum(1 for term in query_terms if term in title)
-        if matches > 0:
-            boost = TITLE_MATCH_BOOST * (matches / len(query_terms))
-            chunk.relevance_score = (chunk.relevance_score or 0) + boost
-
-
-def apply_popularity_boost(chunks: list[Chunk]) -> None:
-    """
-    Boost chunks based on source popularity.
-
-    Uses the popularity property from SourceItem subclasses.
-    ForumPost uses karma, others default to 1.0.
+    - Title boost: chunks get boosted when query terms appear in source title
+    - Popularity boost: chunks get boosted based on source karma/popularity
+    - Recency boost: newer content gets a small boost that decays over time
     """
     if not chunks:
         return
 
     source_ids = list({chunk.source_id for chunk in chunks})
+    now = datetime.now(timezone.utc)
 
+    # Single query to fetch all source metadata
     with make_session() as db:
         sources = db.query(SourceItem).filter(
             SourceItem.id.in_(source_ids)
         ).all()
-        popularity_map = {s.id: s.popularity for s in sources}
+        source_map = {
+            s.id: {
+                "title": (getattr(s, "title", None) or "").lower(),
+                "popularity": s.popularity,
+                "inserted_at": s.inserted_at,
+            }
+            for s in sources
+        }
 
     for chunk in chunks:
-        popularity = popularity_map.get(chunk.source_id, 1.0)
+        source_data = source_map.get(chunk.source_id, {})
+        score = chunk.relevance_score or 0
+
+        # Apply title boost if query terms match
+        if query_terms:
+            title = source_data.get("title", "")
+            if title:
+                matches = sum(1 for term in query_terms if term in title)
+                if matches > 0:
+                    score += TITLE_MATCH_BOOST * (matches / len(query_terms))
+
+        # Apply popularity boost
+        popularity = source_data.get("popularity", 1.0)
         if popularity != 1.0:
-            # Apply boost: score * (1 + POPULARITY_BOOST * (popularity - 1))
-            # For popularity=2.0: multiplier = 1.02
-            # For popularity=0.5: multiplier = 0.99
             multiplier = 1.0 + POPULARITY_BOOST * (popularity - 1.0)
-            chunk.relevance_score = (chunk.relevance_score or 0) * multiplier
+            score *= multiplier
+
+        # Apply recency boost (exponential decay with half-life)
+        inserted_at = source_data.get("inserted_at")
+        if inserted_at:
+            # Handle timezone-naive timestamps
+            if inserted_at.tzinfo is None:
+                inserted_at = inserted_at.replace(tzinfo=timezone.utc)
+            age_days = (now - inserted_at).total_seconds() / 86400
+            # Exponential decay: boost = max_boost * 0.5^(age/half_life)
+            decay = math.pow(0.5, age_days / RECENCY_HALF_LIFE_DAYS)
+            score += RECENCY_BOOST_MAX * decay
+
+        chunk.relevance_score = score
+
+
+# Keep legacy functions for backwards compatibility and testing
+def apply_title_boost(chunks: list[Chunk], query_terms: set[str]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, query_terms)
+
+
+def apply_popularity_boost(chunks: list[Chunk]) -> None:
+    """Legacy function - use apply_source_boosts instead."""
+    apply_source_boosts(chunks, set())
 
 
 def fuse_scores_rrf(
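For intuition, here is how the three boosts in apply_source_boosts compose for one chunk, with made-up numbers (base score 0.5, one of two query terms in the title, popularity 2.0, content exactly one half-life old), following the order of operations above:

    score = 0.5
    score += TITLE_MATCH_BOOST * (1 / 2)                # +0.005  -> 0.505
    score *= 1.0 + POPULARITY_BOOST * (2.0 - 1.0)       # x1.02   -> 0.5151
    score += RECENCY_BOOST_MAX * 0.5 ** (90 / RECENCY_HALF_LIFE_DAYS)  # +0.0025 -> 0.5176

All three signals are small by design, so they break ties rather than override relevance.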
@@ -242,12 +327,21 @@ async def search_chunks(
     # This helps find results that rank well in one method but not the other
     internal_limit = limit * CANDIDATE_MULTIPLIER
 
-    # Extract query text for HyDE expansion
-    search_data = list(data)  # Copy to avoid modifying original
-    if settings.ENABLE_HYDE_EXPANSION:
+    # Extract query text and apply synonym/abbreviation expansion
     query_text = " ".join(
         c for chunk in data for c in chunk.data if isinstance(c, str)
     )
+    expanded_query = expand_query(query_text)
+
+    # If query was expanded, use expanded version for search
+    if expanded_query != query_text:
+        logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
+        search_data = [extract.DataChunk(data=[expanded_query])]
+    else:
+        search_data = list(data)  # Copy to avoid modifying original
+
+    # Apply HyDE expansion if enabled
+    if settings.ENABLE_HYDE_EXPANSION:
         # Only expand queries with 4+ words (short queries are usually specific enough)
         if len(query_text.split()) >= 4:
             try:
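Note the behavior change here: when expansion fires, the expanded text replaces the original payload as a single synthetic chunk. A minimal sketch of that branch (assuming one plain-text input chunk; DataChunk usage as in the diff itself):

    query_text = "EA organizations"
    expanded = expand_query(query_text)  # "EA organizations effective altruism"
    if expanded != query_text:
        # downstream embedding and BM25 both see the enriched text
        search_data = [extract.DataChunk(data=[expanded])]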
@@ -316,15 +410,15 @@ async def search_chunks(
         c for chunk in data for c in chunk.data if isinstance(c, str)
     )
 
-    # Apply query term presence boost and title boost
+    # Apply query term presence boost
     if chunks and query_text.strip():
         query_terms = extract_query_terms(query_text)
         apply_query_term_boost(chunks, query_terms)
-        apply_title_boost(chunks, query_terms)
-
-    # Apply popularity boost (karma-based for forum posts)
-    if chunks:
-        apply_popularity_boost(chunks)
+        # Apply title + popularity boosts (single DB query)
+        apply_source_boosts(chunks, query_terms)
+    elif chunks:
+        # No query terms, just apply popularity boost
+        apply_source_boosts(chunks, set())
 
     # Rerank using cross-encoder for better precision
     if settings.ENABLE_RERANKING and chunks and query_text.strip():
@@ -178,6 +178,8 @@ ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
 ENABLE_SEARCH_SCORING = boolean_env("ENABLE_SEARCH_SCORING", True)
 ENABLE_HYDE_EXPANSION = boolean_env("ENABLE_HYDE_EXPANSION", True)
 HYDE_TIMEOUT = float(os.getenv("HYDE_TIMEOUT", "3.0"))
+ENABLE_RERANKING = boolean_env("ENABLE_RERANKING", True)
+RERANK_MODEL = os.getenv("RERANK_MODEL", "rerank-2-lite")
 MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
 
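The two new knobs follow the same env-var pattern as the rest of the file, so reranking can be tuned per deployment without code changes. A sketch of overriding them in a test (the accepted true/false spellings for boolean_env are an assumption here, and the settings module is evaluated at import time):

    import importlib
    import os

    os.environ["ENABLE_RERANKING"] = "false"  # skip the cross-encoder pass
    os.environ["RERANK_MODEL"] = "rerank-2"   # hypothetical alternative model name

    from memory.common import settings
    importlib.reload(settings)  # re-evaluate the module-level settings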
@@ -6,17 +6,23 @@ title boosting, and source deduplication.
 import pytest
 from unittest.mock import MagicMock, patch
 
+from datetime import datetime, timedelta, timezone
+
 from memory.api.search.search import (
     extract_query_terms,
     apply_query_term_boost,
     deduplicate_by_source,
     apply_title_boost,
     apply_popularity_boost,
+    apply_source_boosts,
+    expand_query,
     fuse_scores_rrf,
     STOPWORDS,
     QUERY_TERM_BOOST,
     TITLE_MATCH_BOOST,
     POPULARITY_BOOST,
+    RECENCY_BOOST_MAX,
+    RECENCY_HALF_LIFE_DAYS,
     RRF_K,
 )
 
@@ -239,6 +245,8 @@ def test_apply_title_boost(mock_make_session, title, query_terms, initial_score,
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = title
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
 
     chunks = [_make_title_chunk(1, initial_score)]
@@ -268,6 +276,8 @@ def test_apply_title_boost_none_title(mock_make_session):
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.title = None
+    mock_source.popularity = 1.0  # Default popularity, no boost
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
 
     chunks = [_make_title_chunk(1, 0.5)]
@@ -307,6 +317,7 @@ def test_apply_popularity_boost(mock_make_session, popularity, initial_score, ex
     mock_source = MagicMock()
     mock_source.id = 1
     mock_source.popularity = popularity
+    mock_source.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
 
     chunks = [_make_pop_chunk(1, initial_score)]
@@ -331,9 +342,11 @@ def test_apply_popularity_boost_multiple_sources(mock_make_session):
     source1 = MagicMock()
     source1.id = 1
     source1.popularity = 2.0  # High karma
+    source1.inserted_at = None  # No recency boost
     source2 = MagicMock()
     source2.id = 2
     source2.popularity = 1.0  # Default
+    source2.inserted_at = None  # No recency boost
     mock_session.query.return_value.filter.return_value.all.return_value = [source1, source2]
 
     chunks = [_make_pop_chunk(1, 0.5), _make_pop_chunk(2, 0.5)]
@@ -423,3 +436,233 @@ def test_fuse_scores_rrf_only_ranks_matter():
     assert result1["a"] == pytest.approx(result2["a"])
     assert result1["b"] == pytest.approx(result2["b"])
     assert result1["c"] == pytest.approx(result2["c"])
+
+
+# ============================================================================
+# apply_source_boosts recency tests
+# ============================================================================
+
+
+def _make_recency_chunk(source_id: int, score: float = 0.5):
+    """Create a mock chunk for recency boost tests."""
+    chunk = MagicMock()
+    chunk.source_id = source_id
+    chunk.relevance_score = score
+    return chunk
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_new_content(mock_make_session):
+    """Brand new content should get full recency boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now  # Just inserted
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get nearly full recency boost
+    expected = 0.5 + RECENCY_BOOST_MAX
+    assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_half_life_decay(mock_make_session):
+    """Content at half-life age should get half the boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now - timedelta(days=RECENCY_HALF_LIFE_DAYS)
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get half the recency boost
+    expected = 0.5 + RECENCY_BOOST_MAX * 0.5
+    assert chunks[0].relevance_score == pytest.approx(expected, rel=0.01)
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_old_content(mock_make_session):
+    """Very old content should get minimal recency boost."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = now - timedelta(days=365)  # 1 year old
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Should get very little boost (about 0.5^4 ≈ 0.0625 of max)
+    assert chunks[0].relevance_score > 0.5
+    assert chunks[0].relevance_score < 0.5 + RECENCY_BOOST_MAX * 0.1
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_none_timestamp(mock_make_session):
+    """Should handle None inserted_at gracefully."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = None
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # No recency boost applied
+    assert chunks[0].relevance_score == 0.5
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_timezone_naive(mock_make_session):
+    """Should handle timezone-naive timestamps."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    # Timezone-naive timestamp
+    naive_dt = datetime.now().replace(tzinfo=None)
+    mock_source = MagicMock()
+    mock_source.id = 1
+    mock_source.title = None
+    mock_source.popularity = 1.0
+    mock_source.inserted_at = naive_dt
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_source]
+
+    chunks = [_make_recency_chunk(1, 0.5)]
+    apply_source_boosts(chunks, set())  # Should not raise
+
+    # Should get nearly full boost since it's very recent
+    assert chunks[0].relevance_score > 0.5
+
+
+@patch("memory.api.search.search.make_session")
+def test_recency_boost_ordering(mock_make_session):
+    """Newer content should rank higher than older content."""
+    mock_session = MagicMock()
+    mock_make_session.return_value.__enter__ = MagicMock(return_value=mock_session)
+    mock_make_session.return_value.__exit__ = MagicMock(return_value=None)
+
+    now = datetime.now(timezone.utc)
+    source_new = MagicMock()
+    source_new.id = 1
+    source_new.title = None
+    source_new.popularity = 1.0
+    source_new.inserted_at = now - timedelta(days=1)
+
+    source_old = MagicMock()
+    source_old.id = 2
+    source_old.title = None
+    source_old.popularity = 1.0
+    source_old.inserted_at = now - timedelta(days=180)
+
+    mock_session.query.return_value.filter.return_value.all.return_value = [source_new, source_old]
+
+    chunks = [_make_recency_chunk(1, 0.5), _make_recency_chunk(2, 0.5)]
+    apply_source_boosts(chunks, set())
+
+    # Newer content should have higher score
+    assert chunks[0].relevance_score > chunks[1].relevance_score
+
+
+# ============================================================================
+# expand_query tests
+# ============================================================================
+
+
+@pytest.mark.parametrize(
+    "query,expected_expansion",
+    [
+        # AI/ML abbreviations
+        ("ML algorithms", "machine learning"),
+        ("AI safety research", "artificial intelligence"),
+        ("NLP models for text", "natural language processing"),
+        ("deep learning vs DL", "deep learning"),
+        # Rationality/EA terms
+        ("EA organizations", "effective altruism"),
+        ("AGI timeline predictions", "artificial general intelligence"),
+        ("x-risk reduction", "existential risk"),
+        # Reverse mappings
+        ("machine learning basics", "ml"),
+        ("artificial intelligence ethics", "ai"),
+    ],
+)
+def test_expand_query_adds_synonyms(query, expected_expansion):
+    """Should expand abbreviations and add synonyms."""
+    result = expand_query(query)
+    assert expected_expansion in result.lower()
+    assert query in result  # Original query preserved
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "hello world",  # No expansions
+        "python programming",  # No expansions
+        "database optimization",  # No expansions
+    ],
+)
+def test_expand_query_no_match(query):
+    """Should return original query when no expansions match."""
+    result = expand_query(query)
+    assert result == query
+
+
+def test_expand_query_case_insensitive():
+    """Should match terms regardless of case."""
+    assert "artificial intelligence" in expand_query("AI research").lower()
+    assert "artificial intelligence" in expand_query("ai research").lower()
+    assert "artificial intelligence" in expand_query("Ai Research").lower()
+
+
+def test_expand_query_word_boundaries():
+    """Should only match whole words, not partial matches."""
+    # "mail" contains "ai" but shouldn't trigger expansion
+    result = expand_query("email server")
+    assert result == "email server"
+
+    # "claim" contains "ai" but shouldn't trigger expansion
+    result = expand_query("insurance claim")
+    assert result == "insurance claim"
+
+
+def test_expand_query_multiple_terms():
+    """Should expand multiple matching terms."""
+    result = expand_query("AI and ML applications")
+    assert "artificial intelligence" in result.lower()
+    assert "machine learning" in result.lower()
+
+
+def test_expand_query_preserves_original():
+    """Should preserve original query text."""
+    original = "AI safety research"
+    result = expand_query(original)
+    assert result.startswith(original)
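To put numbers on the decay curve these recency tests exercise, a back-of-the-envelope sketch of the additive boost at a few ages (half-life 90 days and max boost 0.005, matching the constants introduced above):

    import math

    for age_days in (0, 1, 90, 180, 365):
        decay = math.pow(0.5, age_days / 90)      # RECENCY_HALF_LIFE_DAYS
        print(age_days, round(0.005 * decay, 5))  # RECENCY_BOOST_MAX * decay
    # -> 0.005, 0.00496, 0.0025, 0.00125, 0.0003

Even brand-new content gains at most 0.005, which is why the tests treat recency as a tiebreaker rather than a dominant signal.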