mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 17:22:58 +01:00
Refactor search: add LLM query analysis, extract constants
- Add query_analysis.py for LLM-based query preprocessing
- Detects modalities from natural language ("on lesswrong" -> forum)
- Cleans meta-language ("I remember reading..." -> core query)
- Generates query variants for better recall
- Dynamically discovers modalities and domains from database
- Extract constants to constants.py
- STOPWORDS, RRF_K, boost values, etc.
- Cleaner separation of configuration from logic
- Refactor search_chunks into focused helper functions
- _run_llm_analysis: parallel query analysis + HyDE
- _apply_query_analysis: apply analysis results
- _build_search_data: construct search data with variants
- _run_searches: embedding + BM25 with RRF fusion
- _fetch_chunks: database retrieval with scoring
- _apply_boosts: title, popularity, recency boosts
- _apply_reranking: cross-encoder reranking
- Remove redundant regex-based modality detection
- Remove static QUERY_EXPANSIONS (LLM handles this better)
- Add comprehensive tests for query_analysis module
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
60e6e18284
commit
782b56939f
75
src/memory/api/search/constants.py
Normal file
75
src/memory/api/search/constants.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
Constants for search functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- Score fusion ---------------------------------------------------------

# k parameter of Reciprocal Rank Fusion. Larger values flatten the advantage
# of top-ranked documents; 60 is the value used in the original RRF paper.
RRF_K = 60

# Internal over-fetch factor applied to the requested limit before fusion.
# Fetching more candidates than needed lets results that rank well in only
# one retrieval method still make it into the fused top N.
CANDIDATE_MULTIPLIER = 5

# Number of candidates handed to the reranker, as a multiple of the final
# limit. Larger values are more accurate but slower and more expensive.
RERANK_CANDIDATE_MULTIPLIER = 3

# --- Additive / multiplicative boosts --------------------------------------

# Added to the RRF score of chunks that contain query terms.
QUERY_TERM_BOOST = 0.005

# Added when query terms match the source title (a stronger signal).
TITLE_MATCH_BOOST = 0.01

# Popularity multiplier, applied as:
#   score * (1 + POPULARITY_BOOST * (popularity - 1))
# Gives popular items a small lift without dominating relevance.
POPULARITY_BOOST = 0.02

# Recency boost: maximum additive bonus for brand new content...
RECENCY_BOOST_MAX = 0.005
# ...which halves every RECENCY_HALF_LIFE_DAYS days.
RECENCY_HALF_LIFE_DAYS = 90

# --- Stopwords ---------------------------------------------------------------

# Common words to ignore when checking for query term presence.
STOPWORDS = frozenset(
    (
        # Articles
        "a an the "
        # Be verbs
        "is are was were be been being "
        # Have verbs
        "have has had "
        # Do verbs
        "do does did "
        # Modal verbs
        "will would could should may might must shall can "
        "need dare ought used "
        # Prepositions
        "to of in for on with at by from as into "
        "through during before after above below between under "
        # Adverbs
        "again further then once here there when where why how "
        # Quantifiers
        "all each few more most other some such "
        # Negation
        "no nor not "
        # Other common words
        "only own same so than too very just "
        # Conjunctions
        "and but if or because until while although though "
        # Relative pronouns
        "what which who whom "
        # Demonstratives
        "this that these those "
        # Personal pronouns
        "i me my myself "
        "we our ours ourselves "
        "you your yours yourself yourselves "
        "he him his himself "
        "she her hers herself "
        "it its itself "
        "they them their theirs themselves "
        # Misc common words
        "about get got getting like also"
    ).split()
)
|
||||||
327
src/memory/api/search/query_analysis.py
Normal file
327
src/memory/api/search/query_analysis.py
Normal file
@ -0,0 +1,327 @@
|
|||||||
|
"""
|
||||||
|
LLM-based query analysis for intelligent search preprocessing.
|
||||||
|
|
||||||
|
Uses a fast LLM (Haiku) to analyze natural language queries and extract:
|
||||||
|
- Modalities: content types to search (forum, book, comic, etc.)
|
||||||
|
- Source hints: author names, domains, or specific sources
|
||||||
|
- Cleaned query: the actual search terms with meta-language removed
|
||||||
|
- Query variants: alternative phrasings to search
|
||||||
|
|
||||||
|
This runs in parallel with HyDE for maximum efficiency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import textwrap
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sqlalchemy import func, text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from memory.common import settings
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import SourceItem
|
||||||
|
from memory.common.llms import create_provider, LLMSettings, Message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Threshold for listing specific sources (a modality's domains are listed only when it has at most N distinct domains)
|
||||||
|
MAX_DOMAINS_TO_LIST = 10
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModalityInfo:
|
||||||
|
"""Information about a modality from the database."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
count: int
|
||||||
|
domains: list[str] = field(default_factory=list)
|
||||||
|
source_count: int | None = None # For modalities with parent entities (e.g., books)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
"""Build description including domains if there are few enough."""
|
||||||
|
if self.source_count:
|
||||||
|
base = f"{self.source_count:,} sources ({self.count:,} sections)"
|
||||||
|
else:
|
||||||
|
base = f"{self.count:,} items"
|
||||||
|
|
||||||
|
if self.domains:
|
||||||
|
return f"{base} from: {', '.join(self.domains)}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
# Cache for database-derived information
|
||||||
|
_modality_cache: dict[str, ModalityInfo] = {}
|
||||||
|
_cache_timestamp: float = 0
|
||||||
|
_CACHE_TTL_SECONDS = 3600 # Refresh every hour
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tables_with_url_column(db) -> list[str]:
    """Return the names of tables in the 'public' schema that have a 'url' column.

    The lookup goes through information_schema.columns using the given
    database handle.
    """
    schema_query = text("""
        SELECT table_name FROM information_schema.columns
        WHERE column_name = 'url' AND table_schema = 'public'
    """)
    matching_rows = db.execute(schema_query)
    # Each row has a single column: the table name.
    return [record[0] for record in matching_rows]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_modality_domains(db) -> dict[str, list[str]]:
    """Get the distinct URL domains of each modality that has URL data.

    Every table with a ``url`` column is joined against ``source_item`` and the
    network location of each URL is collected per modality.

    Args:
        db: Database handle used to run the queries.

    Returns:
        Mapping of modality name to its sorted domain list. Modalities with
        more than MAX_DOMAINS_TO_LIST distinct domains are omitted. An empty
        dict is returned when no table has a ``url`` column or the query fails.
    """
    tables = _get_tables_with_url_column(db)
    if not tables:
        return {}

    # Build a UNION query to get modality + url from all URL-containing tables.
    # The names come from information_schema, but they are still quoted (with
    # embedded quotes doubled) so mixed-case or reserved-word table names
    # produce valid SQL.
    union_parts = []
    for table in tables:
        quoted_table = '"' + table.replace('"', '""') + '"'
        union_parts.append(
            f"SELECT s.modality, t.url FROM source_item s "
            f"JOIN {quoted_table} t ON s.id = t.id WHERE t.url IS NOT NULL"
        )
    query = " UNION ALL ".join(union_parts)

    try:
        rows = list(db.execute(text(query)))
    except SQLAlchemyError as e:
        logger.debug(f"Database error getting modality URLs: {e}")
        return {}

    # Group URLs by modality and extract domains
    modality_domains: dict[str, set[str]] = {}
    for modality, url in rows:
        if not modality or not url:
            continue
        try:
            domain = urlparse(url).netloc
        except ValueError:
            # urlparse rejects some malformed URLs; skip them.
            continue

        # Host names are case-insensitive, so normalise the case first. Only a
        # *leading* "www." is dropped: str.replace would also eat "www." in the
        # middle of a host (e.g. "awww.example.com" -> "aexample.com").
        domain = domain.lower().removeprefix("www.")
        if domain:
            modality_domains.setdefault(modality, set()).add(domain)

    # Only return domains for modalities with few enough to list
    return {
        modality: sorted(domains)
        for modality, domains in modality_domains.items()
        if len(domains) <= MAX_DOMAINS_TO_LIST
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _get_source_counts(db) -> dict[str, int]:
    """Get distinct source counts for modalities with parent entities.

    Currently only books are handled: the number of distinct ``book_id``
    values in ``book_section``.

    Args:
        db: Database handle used to run the query.

    Returns:
        ``{"book": <count>}``, or an empty dict when the query fails.
    """
    try:
        # Books: count distinct book_id from book_section
        result = db.execute(text(
            "SELECT COUNT(DISTINCT book_id) FROM book_section"
        ))
        book_count = result.scalar() or 0
    except SQLAlchemyError as e:
        # Best effort: the counts only enrich the prompt, so a failure is not
        # fatal, but it should not disappear without a trace either.
        logger.debug(f"Database error getting source counts: {e}")
        return {}

    return {"book": book_count}
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_modality_cache() -> None:
    """Rebuild the modality cache from the database.

    Counts source items per modality, attaches the listable domains and the
    parent-source counts, then replaces the module-level cache and records the
    refresh time. On a database error the previous cache and timestamp are
    left untouched and a warning is logged.
    """
    global _modality_cache, _cache_timestamp

    try:
        with make_session() as db:
            # Get modality counts
            results = (
                db.query(SourceItem.modality, func.count(SourceItem.id))
                .group_by(SourceItem.modality)
                .order_by(func.count(SourceItem.id).desc())
                .all()
            )

            # Get domains for modalities with URLs (single query)
            modality_domains = _get_modality_domains(db)

            # Get source counts for modalities with parent entities
            source_counts = _get_source_counts(db)

            # Build the replacement in a local dict and publish it with one
            # assignment, so readers never observe an empty or half-filled
            # cache while the refresh is in progress.
            refreshed = {
                modality: ModalityInfo(
                    name=modality,
                    count=count,
                    domains=modality_domains.get(modality, []),
                    source_count=source_counts.get(modality),
                )
                for modality, count in results
                if modality and count > 0
            }

        _modality_cache = refreshed
        _cache_timestamp = time.time()
        logger.debug(f"Refreshed modality cache: {list(_modality_cache.keys())}")

    except SQLAlchemyError as e:
        logger.warning(f"Database error refreshing modality cache: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_modalities() -> dict[str, ModalityInfo]:
    """Return the cached modalities, refreshing them first when needed.

    A refresh is triggered when the cache is older than _CACHE_TTL_SECONDS
    or when it is empty.
    """
    cache_age = time.time() - _cache_timestamp
    is_stale = cache_age > _CACHE_TTL_SECONDS

    if is_stale or not _modality_cache:
        _refresh_modality_cache()

    return _modality_cache
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt() -> str:
    """Render the query-analysis system prompt for the modalities currently available.

    The prompt lists every available modality with its description and asks
    the model for a JSON object holding modalities, sources, a cleaned query
    and query variants.
    """
    available = _get_available_modalities()

    if available:
        modality_section = "\n".join(
            f" - {info.name}: {info.description}" for info in available.values()
        )
    else:
        modality_section = " (no content indexed yet)"

    # Dedent and strip the template *before* substituting, so the multi-line
    # modality section cannot interfere with the dedent.
    template = textwrap.dedent("""
        Analyze this search query and extract structured information.

        The user is searching a personal knowledge base containing:
        {modality_section}

        Return a JSON object:
        {{
            "modalities": [], // From: {modality_names} (empty = search all)
            "sources": [], // Specific sources/authors mentioned
            "cleaned_query": "", // Query with meta-language removed
            "query_variants": [] // 1-3 alternative phrasings
        }}

        Guidelines:
        - "on lesswrong" -> forum, "comic about" -> comic, etc.
        - Remove "there was something about", "I remember reading", etc.
        - Generate useful query variants

        Return ONLY valid JSON.
    """).strip()

    return template.format(
        modality_section=modality_section,
        modality_names=list(available),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QueryAnalysis:
    """Outcome of analysing a search query with the LLM."""

    # Modalities the query should be restricted to; empty when none were extracted.
    modalities: set[str] = field(default_factory=set)
    # Specific sources/authors mentioned in the query.
    sources: list[str] = field(default_factory=list)
    # The query with meta-language removed.
    cleaned_query: str = ""
    # Alternative phrasings of the query.
    query_variants: list[str] = field(default_factory=list)
    # True only when the LLM response was parsed and applied.
    success: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# Cache for recent analyses
|
||||||
|
_analysis_cache: dict[str, QueryAnalysis] = {}
|
||||||
|
_CACHE_MAX_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_code_fence(raw: str) -> str:
    """Return *raw* without a surrounding markdown code fence and language tag."""
    payload = raw.strip()
    if payload.startswith("```"):
        # Keep only what sits between the first pair of fences.
        payload = payload.split("```")[1]
        if payload[:4].lower() == "json":
            payload = payload[4:]
    return payload.strip()


def _string_list(value: object) -> list[str]:
    """Return the non-blank strings of *value*, or [] when it is not a list."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, str) and item.strip()]


async def analyze_query(
    query: str,
    model: Optional[str] = None,
    timeout: float = 3.0,
) -> QueryAnalysis:
    """
    Analyze a search query using an LLM to extract search parameters.

    Successful analyses are cached by the lower-cased, stripped query. Failed
    ones (timeout, provider error, empty or unparseable response) are not
    cached, so a transient failure does not stick to the query.

    Args:
        query: The user's natural language search query
        model: LLM model to use (defaults to settings.SUMMARIZER_MODEL)
        timeout: Maximum time in seconds to wait for the LLM response

    Returns:
        QueryAnalysis with extracted modalities, sources, cleaned query, and
        variants. On any failure the result keeps ``success=False`` and
        ``cleaned_query`` equal to the original query; errors are logged
        rather than raised.
    """
    # Check cache first
    cache_key = query.lower().strip()
    cached = _analysis_cache.get(cache_key)
    if cached is not None:
        logger.debug(f"Query analysis cache hit for: {query[:50]}...")
        return cached

    result = QueryAnalysis(cleaned_query=query)

    try:
        provider = create_provider(model=model or settings.SUMMARIZER_MODEL)

        messages = [Message.user(text=f"Query: {query}")]

        llm_settings = LLMSettings(
            temperature=0.1,  # Low temperature for consistent structured output
            max_tokens=300,
        )

        # NOTE: _build_prompt() runs synchronously before the awaited call is
        # created, so a modality-cache refresh it triggers is not bounded by
        # `timeout`.
        response = await asyncio.wait_for(
            provider.agenerate(
                messages=messages,
                system_prompt=_build_prompt(),
                settings=llm_settings,
            ),
            timeout=timeout,
        )

        if not response:
            return result

        try:
            data = json.loads(_strip_code_fence(response))
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse query analysis JSON: {e}")
            return result

        if not isinstance(data, dict):
            logger.warning(
                f"Query analysis returned {type(data).__name__}, expected a JSON object"
            )
            return result

        # Tolerate null or wrongly typed fields instead of discarding the
        # whole analysis.
        result.modalities = set(_string_list(data.get("modalities")))
        result.sources = _string_list(data.get("sources"))
        result.query_variants = _string_list(data.get("query_variants"))

        # An empty or missing cleaned query must not replace the user's query.
        cleaned = data.get("cleaned_query")
        if isinstance(cleaned, str) and cleaned.strip():
            result.cleaned_query = cleaned
        else:
            result.cleaned_query = query
        result.success = True

        logger.debug(
            f"Query analysis: '{query[:40]}...' -> "
            f"modalities={result.modalities}, "
            f"cleaned='{result.cleaned_query[:30]}...'"
        )

        # Cache the result; when full, drop the oldest half first (dicts keep
        # insertion order).
        if len(_analysis_cache) >= _CACHE_MAX_SIZE:
            keys_to_remove = list(_analysis_cache.keys())[: _CACHE_MAX_SIZE // 2]
            for key in keys_to_remove:
                del _analysis_cache[key]
        _analysis_cache[cache_key] = result

    except asyncio.TimeoutError:
        logger.warning(f"Query analysis timed out for: {query[:50]}...")
    except Exception as e:
        # Analysis is an optional enhancement: never let it break the search.
        logger.error(f"Query analysis failed: {e}")

    return result
|
||||||
@ -5,10 +5,8 @@ Search endpoints for the knowledge base API.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
|
||||||
from sqlalchemy.orm import load_only
|
from sqlalchemy.orm import load_only
|
||||||
from memory.common import extract, settings
|
from memory.common import extract, settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
@ -16,6 +14,17 @@ from memory.common.db.models import Chunk, SourceItem
|
|||||||
from memory.common.collections import ALL_COLLECTIONS
|
from memory.common.collections import ALL_COLLECTIONS
|
||||||
from memory.api.search.embeddings import search_chunks_embeddings
|
from memory.api.search.embeddings import search_chunks_embeddings
|
||||||
from memory.api.search import scorer
|
from memory.api.search import scorer
|
||||||
|
from memory.api.search.constants import (
|
||||||
|
RRF_K,
|
||||||
|
CANDIDATE_MULTIPLIER,
|
||||||
|
RERANK_CANDIDATE_MULTIPLIER,
|
||||||
|
QUERY_TERM_BOOST,
|
||||||
|
TITLE_MATCH_BOOST,
|
||||||
|
POPULARITY_BOOST,
|
||||||
|
RECENCY_BOOST_MAX,
|
||||||
|
RECENCY_HALF_LIFE_DAYS,
|
||||||
|
STOPWORDS,
|
||||||
|
)
|
||||||
|
|
||||||
if settings.ENABLE_BM25_SEARCH:
|
if settings.ENABLE_BM25_SEARCH:
|
||||||
from memory.api.search.bm25 import search_bm25_chunks
|
from memory.api.search.bm25 import search_bm25_chunks
|
||||||
@ -26,6 +35,7 @@ if settings.ENABLE_HYDE_EXPANSION:
|
|||||||
if settings.ENABLE_RERANKING:
|
if settings.ENABLE_RERANKING:
|
||||||
from memory.api.search.rerank import rerank_chunks
|
from memory.api.search.rerank import rerank_chunks
|
||||||
|
|
||||||
|
from memory.api.search.query_analysis import analyze_query, QueryAnalysis
|
||||||
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
||||||
|
|
||||||
# Default config for when none is provided
|
# Default config for when none is provided
|
||||||
@ -33,202 +43,6 @@ _DEFAULT_CONFIG = SearchConfig()
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Reciprocal Rank Fusion constant (k parameter)
|
|
||||||
# Higher values reduce the influence of top-ranked documents
|
|
||||||
# 60 is the standard value from the original RRF paper
|
|
||||||
RRF_K = 60
|
|
||||||
|
|
||||||
# Multiplier for internal search limit before fusion
|
|
||||||
# We search for more candidates than requested, fuse scores, then return top N
|
|
||||||
# This helps find results that rank well in one method but not the other
|
|
||||||
CANDIDATE_MULTIPLIER = 5
|
|
||||||
|
|
||||||
# How many candidates to pass to reranker (multiplier of final limit)
|
|
||||||
# Higher = more accurate but slower and more expensive
|
|
||||||
RERANK_CANDIDATE_MULTIPLIER = 3
|
|
||||||
|
|
||||||
# Bonus for chunks containing query terms (added to RRF score)
|
|
||||||
QUERY_TERM_BOOST = 0.005
|
|
||||||
|
|
||||||
# Bonus when query terms match the source title (stronger signal)
|
|
||||||
TITLE_MATCH_BOOST = 0.01
|
|
||||||
|
|
||||||
# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1)))
|
|
||||||
# This gives a small boost to popular items without dominating relevance
|
|
||||||
POPULARITY_BOOST = 0.02
|
|
||||||
|
|
||||||
# Recency boost settings
|
|
||||||
# Maximum bonus for brand new content (additive)
|
|
||||||
RECENCY_BOOST_MAX = 0.005
|
|
||||||
# Half-life in days: content loses half its recency boost every N days
|
|
||||||
RECENCY_HALF_LIFE_DAYS = 90
|
|
||||||
|
|
||||||
# Query expansion: map abbreviations/acronyms to full forms
|
|
||||||
# These help match when users search for "ML" but documents say "machine learning"
|
|
||||||
QUERY_EXPANSIONS: dict[str, list[str]] = {
|
|
||||||
# AI/ML abbreviations
|
|
||||||
"ai": ["artificial intelligence"],
|
|
||||||
"ml": ["machine learning"],
|
|
||||||
"dl": ["deep learning"],
|
|
||||||
"nlp": ["natural language processing"],
|
|
||||||
"cv": ["computer vision"],
|
|
||||||
"rl": ["reinforcement learning"],
|
|
||||||
"llm": ["large language model", "language model"],
|
|
||||||
"gpt": ["generative pretrained transformer", "language model"],
|
|
||||||
"nn": ["neural network"],
|
|
||||||
"cnn": ["convolutional neural network"],
|
|
||||||
"rnn": ["recurrent neural network"],
|
|
||||||
"lstm": ["long short term memory"],
|
|
||||||
"gan": ["generative adversarial network"],
|
|
||||||
"rag": ["retrieval augmented generation"],
|
|
||||||
# Rationality/EA terms
|
|
||||||
"ea": ["effective altruism"],
|
|
||||||
"lw": ["lesswrong", "less wrong"],
|
|
||||||
"gwwc": ["giving what we can"],
|
|
||||||
"agi": ["artificial general intelligence"],
|
|
||||||
"asi": ["artificial superintelligence"],
|
|
||||||
"fai": ["friendly ai", "ai alignment"],
|
|
||||||
"x-risk": ["existential risk"],
|
|
||||||
"xrisk": ["existential risk"],
|
|
||||||
"p(doom)": ["probability of doom", "ai risk"],
|
|
||||||
# Reverse mappings (full forms -> abbreviations)
|
|
||||||
"artificial intelligence": ["ai"],
|
|
||||||
"machine learning": ["ml"],
|
|
||||||
"deep learning": ["dl"],
|
|
||||||
"natural language processing": ["nlp"],
|
|
||||||
"computer vision": ["cv"],
|
|
||||||
"reinforcement learning": ["rl"],
|
|
||||||
"neural network": ["nn"],
|
|
||||||
"effective altruism": ["ea"],
|
|
||||||
"existential risk": ["x-risk", "xrisk"],
|
|
||||||
# Family relationships (bidirectional)
|
|
||||||
"father": ["son", "daughter", "child", "parent", "dad"],
|
|
||||||
"mother": ["son", "daughter", "child", "parent", "mom"],
|
|
||||||
"parent": ["child", "son", "daughter", "father", "mother"],
|
|
||||||
"son": ["father", "parent", "child"],
|
|
||||||
"daughter": ["mother", "parent", "child"],
|
|
||||||
"child": ["parent", "father", "mother"],
|
|
||||||
"dad": ["father", "son", "daughter", "child"],
|
|
||||||
"mom": ["mother", "son", "daughter", "child"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Modality detection patterns: map query phrases to collection names
|
|
||||||
# Each entry is (pattern, modalities, strip_pattern)
|
|
||||||
# - pattern: regex to match in query
|
|
||||||
# - modalities: set of collection names to filter to
|
|
||||||
# - strip_pattern: whether to remove the matched text from query
|
|
||||||
MODALITY_PATTERNS: list[tuple[str, set[str], bool]] = [
|
|
||||||
# Comics
|
|
||||||
(r"\b(comic|comics|webcomic|webcomics)\b", {"comic"}, True),
|
|
||||||
# Forum posts (LessWrong, EA Forum, etc.)
|
|
||||||
(r"\b(on\s+)?(lesswrong|lw|less\s+wrong)\b", {"forum"}, True),
|
|
||||||
(r"\b(on\s+)?(ea\s+forum|effective\s+altruism\s+forum)\b", {"forum"}, True),
|
|
||||||
(r"\b(on\s+)?(alignment\s+forum|af)\b", {"forum"}, True),
|
|
||||||
(r"\b(forum\s+post|lw\s+post|post\s+on)\b", {"forum"}, True),
|
|
||||||
# Books
|
|
||||||
(r"\b(in\s+a\s+book|in\s+the\s+book|book|chapter)\b", {"book"}, True),
|
|
||||||
# Blog posts / articles
|
|
||||||
(r"\b(blog\s+post|blog|article)\b", {"blog"}, True),
|
|
||||||
# Email
|
|
||||||
(r"\b(email|e-mail|mail)\b", {"mail"}, True),
|
|
||||||
# Photos / images
|
|
||||||
(r"\b(photo|photograph|picture|image)\b", {"photo"}, True),
|
|
||||||
# Documents
|
|
||||||
(r"\b(document|pdf|doc)\b", {"doc"}, True),
|
|
||||||
# Chat / messages
|
|
||||||
(r"\b(chat|message|discord|slack)\b", {"chat"}, True),
|
|
||||||
# Git
|
|
||||||
(r"\b(commit|git|pull\s+request|pr)\b", {"git"}, True),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Meta-language patterns to strip (these don't indicate modality, just noise)
|
|
||||||
META_LANGUAGE_PATTERNS: list[str] = [
|
|
||||||
r"\bthere\s+was\s+(something|some|some\s+\w+|an?\s+\w+)\s+(about|on)\b",
|
|
||||||
r"\bi\s+remember\s+(reading|seeing|there\s+being)\s*(an?\s+)?",
|
|
||||||
r"\bi\s+(read|saw|found)\s+(something|an?\s+\w+)\s+about\b",
|
|
||||||
r"\bsomething\s+about\b",
|
|
||||||
r"\bsome\s+about\b",
|
|
||||||
r"\bthis\s+whole\s+\w+\s+thing\b",
|
|
||||||
r"\bthat\s+\w+\s+thing\b",
|
|
||||||
r"\bthat\s+about\b", # Clean up leftover "that about"
|
|
||||||
r"\ba\s+about\b", # Clean up leftover "a about"
|
|
||||||
r"\bthe\s+about\b", # Clean up leftover "the about"
|
|
||||||
r"\bthere\s+was\s+some\s+about\b", # Clean up leftover
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_modality_hints(query: str) -> tuple[str, set[str]]:
|
|
||||||
"""
|
|
||||||
Detect content type hints in query and extract modalities.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(cleaned_query, detected_modalities)
|
|
||||||
- cleaned_query: query with modality indicators and meta-language removed
|
|
||||||
- detected_modalities: set of collection names detected from query
|
|
||||||
"""
|
|
||||||
query_lower = query.lower()
|
|
||||||
detected: set[str] = set()
|
|
||||||
cleaned = query
|
|
||||||
|
|
||||||
# First, detect and strip modality patterns
|
|
||||||
for pattern, modalities, strip in MODALITY_PATTERNS:
|
|
||||||
if re.search(pattern, query_lower, re.IGNORECASE):
|
|
||||||
detected.update(modalities)
|
|
||||||
if strip:
|
|
||||||
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
|
|
||||||
|
|
||||||
# Strip meta-language patterns (regardless of modality detection)
|
|
||||||
for pattern in META_LANGUAGE_PATTERNS:
|
|
||||||
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
|
|
||||||
|
|
||||||
# Clean up whitespace
|
|
||||||
cleaned = " ".join(cleaned.split())
|
|
||||||
|
|
||||||
return cleaned, detected
|
|
||||||
|
|
||||||
|
|
||||||
def expand_query(query: str) -> str:
|
|
||||||
"""
|
|
||||||
Expand query with synonyms and abbreviations.
|
|
||||||
|
|
||||||
This helps match documents that use different terminology for the same concept.
|
|
||||||
For example, "ML algorithms" -> "ML machine learning algorithms"
|
|
||||||
"""
|
|
||||||
query_lower = query.lower()
|
|
||||||
expansions = []
|
|
||||||
|
|
||||||
for term, synonyms in QUERY_EXPANSIONS.items():
|
|
||||||
# Check if term appears as a whole word in the query
|
|
||||||
# Use word boundaries to avoid matching partial words
|
|
||||||
pattern = r'\b' + re.escape(term) + r'\b'
|
|
||||||
if re.search(pattern, query_lower):
|
|
||||||
expansions.extend(synonyms)
|
|
||||||
|
|
||||||
if expansions:
|
|
||||||
# Add expansions to the original query
|
|
||||||
return query + " " + " ".join(expansions)
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
# Common words to ignore when checking for query term presence
|
|
||||||
STOPWORDS = {
|
|
||||||
"a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
|
|
||||||
"have", "has", "had", "do", "does", "did", "will", "would", "could",
|
|
||||||
"should", "may", "might", "must", "shall", "can", "need", "dare",
|
|
||||||
"ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
|
|
||||||
"from", "as", "into", "through", "during", "before", "after", "above",
|
|
||||||
"below", "between", "under", "again", "further", "then", "once", "here",
|
|
||||||
"there", "when", "where", "why", "how", "all", "each", "few", "more",
|
|
||||||
"most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
|
||||||
"same", "so", "than", "too", "very", "just", "and", "but", "if", "or",
|
|
||||||
"because", "until", "while", "although", "though", "after", "before",
|
|
||||||
"what", "which", "who", "whom", "this", "that", "these", "those", "i",
|
|
||||||
"me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your",
|
|
||||||
"yours", "yourself", "yourselves", "he", "him", "his", "himself", "she",
|
|
||||||
"her", "hers", "herself", "it", "its", "itself", "they", "them", "their",
|
|
||||||
"theirs", "themselves", "about", "get", "got", "getting", "like", "also",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def extract_query_terms(query: str) -> set[str]:
|
def extract_query_terms(query: str) -> set[str]:
|
||||||
"""Extract meaningful terms from query, filtering stopwords."""
|
"""Extract meaningful terms from query, filtering stopwords."""
|
||||||
@ -270,7 +84,9 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
|
|||||||
source_id = chunk.source_id
|
source_id = chunk.source_id
|
||||||
if source_id not in best_by_source:
|
if source_id not in best_by_source:
|
||||||
best_by_source[source_id] = chunk
|
best_by_source[source_id] = chunk
|
||||||
elif (chunk.relevance_score or 0) > (best_by_source[source_id].relevance_score or 0):
|
elif (chunk.relevance_score or 0) > (
|
||||||
|
best_by_source[source_id].relevance_score or 0
|
||||||
|
):
|
||||||
best_by_source[source_id] = chunk
|
best_by_source[source_id] = chunk
|
||||||
return list(best_by_source.values())
|
return list(best_by_source.values())
|
||||||
|
|
||||||
@ -294,9 +110,7 @@ def apply_source_boosts(
|
|||||||
|
|
||||||
# Single query to fetch all source metadata
|
# Single query to fetch all source metadata
|
||||||
with make_session() as db:
|
with make_session() as db:
|
||||||
sources = db.query(SourceItem).filter(
|
sources = db.query(SourceItem).filter(SourceItem.id.in_(source_ids)).all()
|
||||||
SourceItem.id.in_(source_ids)
|
|
||||||
).all()
|
|
||||||
source_map = {
|
source_map = {
|
||||||
s.id: {
|
s.id: {
|
||||||
"title": (getattr(s, "title", None) or "").lower(),
|
"title": (getattr(s, "title", None) or "").lower(),
|
||||||
@ -369,7 +183,9 @@ def fuse_scores_rrf(
|
|||||||
Dict mapping chunk IDs to RRF scores
|
Dict mapping chunk IDs to RRF scores
|
||||||
"""
|
"""
|
||||||
# Convert scores to ranks (1-indexed)
|
# Convert scores to ranks (1-indexed)
|
||||||
emb_ranked = sorted(embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True)
|
emb_ranked = sorted(
|
||||||
|
embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True
|
||||||
|
)
|
||||||
bm25_ranked = sorted(bm25_scores.keys(), key=lambda x: bm25_scores[x], reverse=True)
|
bm25_ranked = sorted(bm25_scores.keys(), key=lambda x: bm25_scores[x], reverse=True)
|
||||||
|
|
||||||
emb_ranks = {chunk_id: rank + 1 for rank, chunk_id in enumerate(emb_ranked)}
|
emb_ranks = {chunk_id: rank + 1 for rank, chunk_id in enumerate(emb_ranked)}
|
||||||
@ -393,84 +209,131 @@ def fuse_scores_rrf(
|
|||||||
return fused
|
return fused
|
||||||
|
|
||||||
|
|
||||||
async def search_chunks(
|
async def _run_llm_analysis(
|
||||||
data: list[extract.DataChunk],
|
query_text: str,
|
||||||
modalities: set[str] = set(),
|
use_query_analysis: bool,
|
||||||
limit: int = 10,
|
use_hyde: bool,
|
||||||
filters: SearchFilters = {},
|
) -> tuple[QueryAnalysis | None, str | None]:
|
||||||
timeout: int = 2,
|
|
||||||
config: SearchConfig = _DEFAULT_CONFIG,
|
|
||||||
) -> list[Chunk]:
|
|
||||||
"""
|
"""
|
||||||
Search chunks using embedding similarity and optionally BM25.
|
Run LLM-based query analysis and/or HyDE expansion in parallel.
|
||||||
|
|
||||||
Combines results using weighted score fusion, giving bonus to documents
|
Returns:
|
||||||
that match both semantically and lexically.
|
(analysis_result, hyde_doc) tuple
|
||||||
|
|
||||||
If HyDE is enabled, also generates a hypothetical document from the query
|
|
||||||
and includes it in the embedding search for better semantic matching.
|
|
||||||
|
|
||||||
Enhancement flags in config override global settings when set:
|
|
||||||
- useBm25: Enable BM25 lexical search
|
|
||||||
- useHyde: Enable HyDE query expansion
|
|
||||||
- useReranking: Enable cross-encoder reranking
|
|
||||||
- useQueryExpansion: Enable synonym/abbreviation expansion
|
|
||||||
- useModalityDetection: Detect content type hints from query
|
|
||||||
"""
|
"""
|
||||||
# Resolve enhancement flags: config overrides global settings
|
analysis_result: QueryAnalysis | None = None
|
||||||
use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
hyde_doc: str | None = None
|
||||||
use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
|
||||||
use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING
|
|
||||||
use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True
|
|
||||||
use_modality_detection = config.useModalityDetection if config.useModalityDetection is not None else False
|
|
||||||
|
|
||||||
# Search for more candidates than requested, fuse scores, then return top N
|
if not (use_query_analysis or use_hyde):
|
||||||
# This helps find results that rank well in one method but not the other
|
return analysis_result, hyde_doc
|
||||||
internal_limit = limit * CANDIDATE_MULTIPLIER
|
|
||||||
|
|
||||||
# Extract query text
|
tasks = []
|
||||||
query_text = " ".join(
|
|
||||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
if use_query_analysis:
|
||||||
|
tasks.append(("analysis", analyze_query(query_text, timeout=3.0)))
|
||||||
|
|
||||||
|
if use_hyde and len(query_text.split()) >= 4:
|
||||||
|
tasks.append(
|
||||||
|
("hyde", expand_query_hyde(query_text, timeout=settings.HYDE_TIMEOUT))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Detect modality hints and clean query if enabled
|
if not tasks:
|
||||||
if use_modality_detection:
|
return analysis_result, hyde_doc
|
||||||
cleaned_query, detected_modalities = detect_modality_hints(query_text)
|
|
||||||
if detected_modalities:
|
|
||||||
# Override passed modalities with detected ones
|
|
||||||
modalities = detected_modalities
|
|
||||||
logger.debug(f"Modality detection: '{query_text[:50]}...' -> modalities={detected_modalities}")
|
|
||||||
if cleaned_query != query_text:
|
|
||||||
logger.debug(f"Query cleaning: '{query_text[:50]}...' -> '{cleaned_query[:50]}...'")
|
|
||||||
query_text = cleaned_query
|
|
||||||
# Update data with cleaned query for downstream processing
|
|
||||||
data = [extract.DataChunk(data=[cleaned_query])]
|
|
||||||
|
|
||||||
if use_query_expansion:
|
try:
|
||||||
expanded_query = expand_query(query_text)
|
results = await asyncio.gather(
|
||||||
# If query was expanded, use expanded version for search
|
*[task for _, task in tasks], return_exceptions=True
|
||||||
if expanded_query != query_text:
|
)
|
||||||
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
|
||||||
search_data = [extract.DataChunk(data=[expanded_query])]
|
for i, (name, _) in enumerate(tasks):
|
||||||
else:
|
result = results[i]
|
||||||
search_data = list(data) # Copy to avoid modifying original
|
if isinstance(result, Exception):
|
||||||
else:
|
logger.warning(f"{name} failed: {result}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name == "analysis" and result:
|
||||||
|
analysis_result = result
|
||||||
|
elif name == "hyde" and result:
|
||||||
|
hyde_doc = result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Parallel LLM calls failed: {e}")
|
||||||
|
|
||||||
|
return analysis_result, hyde_doc
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_query_analysis(
|
||||||
|
analysis_result: QueryAnalysis,
|
||||||
|
query_text: str,
|
||||||
|
data: list[extract.DataChunk],
|
||||||
|
modalities: set[str],
|
||||||
|
) -> tuple[str, list[extract.DataChunk], set[str], list[str]]:
|
||||||
|
"""
|
||||||
|
Apply query analysis results to modify query, data, and modalities.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(updated_query_text, updated_data, updated_modalities, query_variants)
|
||||||
|
"""
|
||||||
|
query_variants: list[str] = []
|
||||||
|
|
||||||
|
if not (analysis_result and analysis_result.success):
|
||||||
|
return query_text, data, modalities, query_variants
|
||||||
|
|
||||||
|
# Use detected modalities if any
|
||||||
|
if analysis_result.modalities:
|
||||||
|
modalities = analysis_result.modalities
|
||||||
|
logger.debug(f"Query analysis modalities: {modalities}")
|
||||||
|
|
||||||
|
# Use cleaned query
|
||||||
|
if analysis_result.cleaned_query and analysis_result.cleaned_query != query_text:
|
||||||
|
logger.debug(
|
||||||
|
f"Query analysis cleaning: '{query_text[:40]}...' -> '{analysis_result.cleaned_query[:40]}...'"
|
||||||
|
)
|
||||||
|
query_text = analysis_result.cleaned_query
|
||||||
|
data = [extract.DataChunk(data=[analysis_result.cleaned_query])]
|
||||||
|
|
||||||
|
# Collect query variants
|
||||||
|
query_variants.extend(analysis_result.query_variants)
|
||||||
|
|
||||||
|
return query_text, data, modalities, query_variants
|
||||||
|
|
||||||
|
|
||||||
|
def _build_search_data(
|
||||||
|
data: list[extract.DataChunk],
|
||||||
|
hyde_doc: str | None,
|
||||||
|
query_variants: list[str],
|
||||||
|
query_text: str,
|
||||||
|
) -> list[extract.DataChunk]:
|
||||||
|
"""
|
||||||
|
Build the list of data chunks to search with.
|
||||||
|
|
||||||
|
Includes original query, HyDE expansion, and query variants.
|
||||||
|
"""
|
||||||
search_data = list(data)
|
search_data = list(data)
|
||||||
|
|
||||||
# Apply HyDE expansion if enabled
|
# Add HyDE expansion if we got one
|
||||||
if use_hyde:
|
|
||||||
# Only expand queries with 4+ words (short queries are usually specific enough)
|
|
||||||
if len(query_text.split()) >= 4:
|
|
||||||
try:
|
|
||||||
hyde_doc = await expand_query_hyde(
|
|
||||||
query_text, timeout=settings.HYDE_TIMEOUT
|
|
||||||
)
|
|
||||||
if hyde_doc:
|
if hyde_doc:
|
||||||
logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'")
|
logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'")
|
||||||
search_data.append(extract.DataChunk(data=[hyde_doc]))
|
search_data.append(extract.DataChunk(data=[hyde_doc]))
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"HyDE expansion failed, using original query: {e}")
|
|
||||||
|
|
||||||
|
# Add query variants from analysis (limit to 3)
|
||||||
|
for variant in query_variants[:3]:
|
||||||
|
search_data.append(extract.DataChunk(data=[variant]))
|
||||||
|
|
||||||
|
return search_data
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_searches(
|
||||||
|
search_data: list[extract.DataChunk],
|
||||||
|
data: list[extract.DataChunk],
|
||||||
|
modalities: set[str],
|
||||||
|
internal_limit: int,
|
||||||
|
filters: SearchFilters,
|
||||||
|
timeout: int,
|
||||||
|
use_bm25: bool,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""
|
||||||
|
Run embedding and optionally BM25 searches, returning fused scores.
|
||||||
|
"""
|
||||||
# Run embedding search
|
# Run embedding search
|
||||||
embedding_scores = await search_chunks_embeddings(
|
embedding_scores = await search_chunks_embeddings(
|
||||||
search_data, modalities, internal_limit, filters, timeout
|
search_data, modalities, internal_limit, filters, timeout
|
||||||
@ -487,14 +350,25 @@ async def search_chunks(
|
|||||||
logger.warning("BM25 search timed out, using embedding results only")
|
logger.warning("BM25 search timed out, using embedding results only")
|
||||||
|
|
||||||
# Fuse scores from both methods using Reciprocal Rank Fusion
|
# Fuse scores from both methods using Reciprocal Rank Fusion
|
||||||
fused_scores = fuse_scores_rrf(embedding_scores, bm25_scores)
|
return fuse_scores_rrf(embedding_scores, bm25_scores)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_chunks(
|
||||||
|
fused_scores: dict[str, float],
|
||||||
|
limit: int,
|
||||||
|
use_reranking: bool,
|
||||||
|
) -> list[Chunk]:
|
||||||
|
"""
|
||||||
|
Fetch chunk objects from database and set their relevance scores.
|
||||||
|
"""
|
||||||
if not fused_scores:
|
if not fused_scores:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Sort by score and take top results
|
# Sort by score and take top results
|
||||||
# If reranking is enabled, fetch more candidates for the reranker to work with
|
# If reranking is enabled, fetch more candidates for the reranker to work with
|
||||||
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
|
sorted_ids = sorted(
|
||||||
|
fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True
|
||||||
|
)
|
||||||
if use_reranking:
|
if use_reranking:
|
||||||
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
||||||
else:
|
else:
|
||||||
@ -522,29 +396,125 @@ async def search_chunks(
|
|||||||
|
|
||||||
db.expunge_all()
|
db.expunge_all()
|
||||||
|
|
||||||
# Extract query text for boosting and reranking
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_boosts(
|
||||||
|
chunks: list[Chunk],
|
||||||
|
data: list[extract.DataChunk],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Apply query term, title, popularity, and recency boosts to chunks.
|
||||||
|
"""
|
||||||
|
if not chunks:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract query text for boosting
|
||||||
query_text = " ".join(
|
query_text = " ".join(
|
||||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
c for chunk in data for c in chunk.data if isinstance(c, str)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply query term presence boost
|
if query_text.strip():
|
||||||
if chunks and query_text.strip():
|
|
||||||
query_terms = extract_query_terms(query_text)
|
query_terms = extract_query_terms(query_text)
|
||||||
apply_query_term_boost(chunks, query_terms)
|
apply_query_term_boost(chunks, query_terms)
|
||||||
# Apply title + popularity boosts (single DB query)
|
# Apply title + popularity boosts (single DB query)
|
||||||
apply_source_boosts(chunks, query_terms)
|
apply_source_boosts(chunks, query_terms)
|
||||||
elif chunks:
|
else:
|
||||||
# No query terms, just apply popularity boost
|
# No query terms, just apply popularity boost
|
||||||
apply_source_boosts(chunks, set())
|
apply_source_boosts(chunks, set())
|
||||||
|
|
||||||
# Rerank using cross-encoder for better precision
|
|
||||||
if use_reranking and chunks and query_text.strip():
|
async def _apply_reranking(
|
||||||
|
chunks: list[Chunk],
|
||||||
|
query_text: str,
|
||||||
|
limit: int,
|
||||||
|
use_reranking: bool,
|
||||||
|
) -> list[Chunk]:
|
||||||
|
"""
|
||||||
|
Apply cross-encoder reranking if enabled.
|
||||||
|
"""
|
||||||
|
if not (use_reranking and chunks and query_text.strip()):
|
||||||
|
return chunks
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chunks = await rerank_chunks(
|
return await rerank_chunks(
|
||||||
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Reranking failed, using RRF order: {e}")
|
logger.warning(f"Reranking failed, using RRF order: {e}")
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def search_chunks(
|
||||||
|
data: list[extract.DataChunk],
|
||||||
|
modalities: set[str] = set(),
|
||||||
|
limit: int = 10,
|
||||||
|
filters: SearchFilters = {},
|
||||||
|
timeout: int = 2,
|
||||||
|
config: SearchConfig = _DEFAULT_CONFIG,
|
||||||
|
) -> list[Chunk]:
|
||||||
|
"""
|
||||||
|
Search chunks using embedding similarity and optionally BM25.
|
||||||
|
|
||||||
|
Combines results using weighted score fusion, giving bonus to documents
|
||||||
|
that match both semantically and lexically.
|
||||||
|
|
||||||
|
If HyDE is enabled, also generates a hypothetical document from the query
|
||||||
|
and includes it in the embedding search for better semantic matching.
|
||||||
|
|
||||||
|
Enhancement flags in config override global settings when set:
|
||||||
|
- useBm25: Enable BM25 lexical search
|
||||||
|
- useHyde: Enable HyDE query expansion
|
||||||
|
- useReranking: Enable cross-encoder reranking
|
||||||
|
- useQueryAnalysis: LLM-based query analysis (extracts modalities, cleans query, generates variants)
|
||||||
|
"""
|
||||||
|
# Resolve enhancement flags: config overrides global settings
|
||||||
|
use_bm25 = (
|
||||||
|
config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
||||||
|
)
|
||||||
|
use_hyde = (
|
||||||
|
config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
||||||
|
)
|
||||||
|
use_reranking = (
|
||||||
|
config.useReranking
|
||||||
|
if config.useReranking is not None
|
||||||
|
else settings.ENABLE_RERANKING
|
||||||
|
)
|
||||||
|
use_query_analysis = (
|
||||||
|
config.useQueryAnalysis if config.useQueryAnalysis is not None else False
|
||||||
|
)
|
||||||
|
|
||||||
|
internal_limit = limit * CANDIDATE_MULTIPLIER
|
||||||
|
|
||||||
|
# Extract query text
|
||||||
|
query_text = " ".join(c for chunk in data for c in chunk.data if isinstance(c, str))
|
||||||
|
|
||||||
|
# Run LLM-based operations in parallel (query analysis + HyDE)
|
||||||
|
analysis_result, hyde_doc = await _run_llm_analysis(
|
||||||
|
query_text, use_query_analysis, use_hyde
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply query analysis results
|
||||||
|
query_text, data, modalities, query_variants = _apply_query_analysis(
|
||||||
|
analysis_result, query_text, data, modalities
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build search data with HyDE and variants
|
||||||
|
search_data = _build_search_data(data, hyde_doc, query_variants, query_text)
|
||||||
|
|
||||||
|
# Run searches and fuse scores
|
||||||
|
fused_scores = await _run_searches(
|
||||||
|
search_data, data, modalities, internal_limit, filters, timeout, use_bm25
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch chunks from database
|
||||||
|
chunks = _fetch_chunks(fused_scores, limit, use_reranking)
|
||||||
|
|
||||||
|
# Apply various boosts
|
||||||
|
_apply_boosts(chunks, data)
|
||||||
|
|
||||||
|
# Apply reranking if enabled
|
||||||
|
chunks = await _apply_reranking(chunks, query_text, limit, use_reranking)
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|||||||
@ -87,8 +87,7 @@ class SearchConfig(BaseModel):
|
|||||||
useBm25: Optional[bool] = None
|
useBm25: Optional[bool] = None
|
||||||
useHyde: Optional[bool] = None
|
useHyde: Optional[bool] = None
|
||||||
useReranking: Optional[bool] = None
|
useReranking: Optional[bool] = None
|
||||||
useQueryExpansion: Optional[bool] = None
|
useQueryAnalysis: Optional[bool] = None # LLM-based query analysis (Haiku)
|
||||||
useModalityDetection: Optional[bool] = None
|
|
||||||
|
|
||||||
def model_post_init(self, __context) -> None:
|
def model_post_init(self, __context) -> None:
|
||||||
# Enforce reasonable limits
|
# Enforce reasonable limits
|
||||||
|
|||||||
428
tests/memory/api/search/test_query_analysis.py
Normal file
428
tests/memory/api/search/test_query_analysis.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
"""Tests for query_analysis module."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, Mock, AsyncMock
|
||||||
|
|
||||||
|
from memory.api.search.query_analysis import (
|
||||||
|
ModalityInfo,
|
||||||
|
QueryAnalysis,
|
||||||
|
_build_prompt,
|
||||||
|
_get_available_modalities,
|
||||||
|
analyze_query,
|
||||||
|
MAX_DOMAINS_TO_LIST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModalityInfo:
|
||||||
|
"""Tests for ModalityInfo dataclass."""
|
||||||
|
|
||||||
|
def test_description_items_only(self):
|
||||||
|
"""Basic description with just count."""
|
||||||
|
info = ModalityInfo(name="forum", count=1000)
|
||||||
|
assert info.description == "1,000 items"
|
||||||
|
|
||||||
|
def test_description_with_domains(self):
|
||||||
|
"""Description includes domains when present."""
|
||||||
|
info = ModalityInfo(
|
||||||
|
name="forum", count=1000, domains=["lesswrong.com", "example.com"]
|
||||||
|
)
|
||||||
|
assert info.description == "1,000 items from: lesswrong.com, example.com"
|
||||||
|
|
||||||
|
def test_description_with_source_count(self):
|
||||||
|
"""Description shows sources and sections for parent entities."""
|
||||||
|
info = ModalityInfo(name="book", count=22000, source_count=400)
|
||||||
|
assert info.description == "400 sources (22,000 sections)"
|
||||||
|
|
||||||
|
def test_description_with_source_count_and_domains(self):
|
||||||
|
"""Description shows sources, sections, and domains."""
|
||||||
|
info = ModalityInfo(
|
||||||
|
name="comic",
|
||||||
|
count=8000,
|
||||||
|
source_count=50,
|
||||||
|
domains=["xkcd.com", "smbc-comics.com"],
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
info.description
|
||||||
|
== "50 sources (8,000 sections) from: xkcd.com, smbc-comics.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_description_formats_large_numbers(self):
|
||||||
|
"""Numbers are formatted with commas."""
|
||||||
|
info = ModalityInfo(name="book", count=1234567, source_count=9876)
|
||||||
|
assert info.description == "9,876 sources (1,234,567 sections)"
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryAnalysis:
|
||||||
|
"""Tests for QueryAnalysis dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""QueryAnalysis has sensible defaults."""
|
||||||
|
result = QueryAnalysis()
|
||||||
|
assert result.modalities == set()
|
||||||
|
assert result.sources == []
|
||||||
|
assert result.cleaned_query == ""
|
||||||
|
assert result.query_variants == []
|
||||||
|
assert result.success is False
|
||||||
|
|
||||||
|
def test_with_values(self):
|
||||||
|
"""QueryAnalysis stores provided values."""
|
||||||
|
result = QueryAnalysis(
|
||||||
|
modalities={"forum", "book"},
|
||||||
|
sources=["lesswrong.com"],
|
||||||
|
cleaned_query="test query",
|
||||||
|
query_variants=["alternative query"],
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
assert result.modalities == {"forum", "book"}
|
||||||
|
assert result.sources == ["lesswrong.com"]
|
||||||
|
assert result.cleaned_query == "test query"
|
||||||
|
assert result.query_variants == ["alternative query"]
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildPrompt:
|
||||||
|
"""Tests for _build_prompt function."""
|
||||||
|
|
||||||
|
def test_builds_prompt_with_modalities(self):
|
||||||
|
"""Prompt includes modality information."""
|
||||||
|
mock_modalities = {
|
||||||
|
"forum": ModalityInfo(
|
||||||
|
name="forum", count=20000, domains=["lesswrong.com"]
|
||||||
|
),
|
||||||
|
"book": ModalityInfo(name="book", count=22000, source_count=400),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value=mock_modalities,
|
||||||
|
):
|
||||||
|
prompt = _build_prompt()
|
||||||
|
|
||||||
|
assert "forum" in prompt
|
||||||
|
assert "book" in prompt
|
||||||
|
assert "20,000 items from: lesswrong.com" in prompt
|
||||||
|
assert "400 sources (22,000 sections)" in prompt
|
||||||
|
assert "modalities" in prompt
|
||||||
|
assert "cleaned_query" in prompt
|
||||||
|
|
||||||
|
def test_builds_prompt_empty_modalities(self):
|
||||||
|
"""Prompt handles no modalities gracefully."""
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
prompt = _build_prompt()
|
||||||
|
|
||||||
|
assert "no content indexed yet" in prompt
|
||||||
|
|
||||||
|
def test_prompt_contains_json_structure(self):
|
||||||
|
"""Prompt includes expected JSON structure."""
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={"test": ModalityInfo(name="test", count=100)},
|
||||||
|
):
|
||||||
|
prompt = _build_prompt()
|
||||||
|
|
||||||
|
assert '"modalities": []' in prompt
|
||||||
|
assert '"sources": []' in prompt
|
||||||
|
assert '"cleaned_query": ""' in prompt
|
||||||
|
assert '"query_variants": []' in prompt
|
||||||
|
|
||||||
|
def test_prompt_contains_guidelines(self):
|
||||||
|
"""Prompt includes usage guidelines."""
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={"forum": ModalityInfo(name="forum", count=100)},
|
||||||
|
):
|
||||||
|
prompt = _build_prompt()
|
||||||
|
|
||||||
|
assert "lesswrong" in prompt.lower()
|
||||||
|
assert "comic" in prompt.lower()
|
||||||
|
assert "Remove" in prompt
|
||||||
|
assert "Return ONLY valid JSON" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyzeQuery:
|
||||||
|
"""Tests for analyze_query async function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_analysis_on_success(self):
|
||||||
|
"""Successfully parses LLM JSON response."""
|
||||||
|
mock_response = json.dumps(
|
||||||
|
{
|
||||||
|
"modalities": ["forum"],
|
||||||
|
"sources": ["lesswrong.com"],
|
||||||
|
"cleaned_query": "rationality concepts",
|
||||||
|
"query_variants": ["rational thinking", "epistemic rationality"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={"forum": ModalityInfo(name="forum", count=100)},
|
||||||
|
):
|
||||||
|
# Clear any cached results
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("something on lesswrong about rationality")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.modalities == {"forum"}
|
||||||
|
assert result.sources == ["lesswrong.com"]
|
||||||
|
assert result.cleaned_query == "rationality concepts"
|
||||||
|
assert "rational thinking" in result.query_variants
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_markdown_code_blocks(self):
|
||||||
|
"""Strips markdown code blocks from response."""
|
||||||
|
mock_response = """```json
|
||||||
|
{
|
||||||
|
"modalities": ["book"],
|
||||||
|
"sources": [],
|
||||||
|
"cleaned_query": "test query",
|
||||||
|
"query_variants": []
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={"book": ModalityInfo(name="book", count=100)},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("find something in a book")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.modalities == {"book"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_invalid_json(self):
|
||||||
|
"""Returns default result on invalid JSON."""
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value="not valid json {{{")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("test query")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.cleaned_query == "test query"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_timeout(self):
|
||||||
|
"""Returns default result on timeout."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def slow_response(*args, **kwargs):
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
return "{}"
|
||||||
|
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = slow_response
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("test query", timeout=0.01)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.cleaned_query == "test query"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_caches_results(self):
|
||||||
|
"""Caches analysis results for repeated queries."""
|
||||||
|
mock_response = json.dumps(
|
||||||
|
{
|
||||||
|
"modalities": [],
|
||||||
|
"sources": [],
|
||||||
|
"cleaned_query": "cached query",
|
||||||
|
"query_variants": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await analyze_query("test query for caching")
|
||||||
|
# Second call (should use cache)
|
||||||
|
result2 = await analyze_query("test query for caching")
|
||||||
|
|
||||||
|
# Provider should only be called once
|
||||||
|
assert mock_provider.agenerate.call_count == 1
|
||||||
|
assert result1.cleaned_query == result2.cleaned_query
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_case_insensitive(self):
|
||||||
|
"""Cache key is case-insensitive."""
|
||||||
|
mock_response = json.dumps(
|
||||||
|
{
|
||||||
|
"modalities": [],
|
||||||
|
"sources": [],
|
||||||
|
"cleaned_query": "test",
|
||||||
|
"query_variants": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
await analyze_query("Test Query")
|
||||||
|
await analyze_query("test query")
|
||||||
|
await analyze_query("TEST QUERY")
|
||||||
|
|
||||||
|
# All variations should hit the same cache entry
|
||||||
|
assert mock_provider.agenerate.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_empty_response(self):
|
||||||
|
"""Handles empty LLM response gracefully."""
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value="")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("test query")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.cleaned_query == "test query"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_none_response(self):
|
||||||
|
"""Handles None LLM response gracefully."""
|
||||||
|
mock_provider = Mock()
|
||||||
|
mock_provider.agenerate = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis.create_provider",
|
||||||
|
return_value=mock_provider,
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"memory.api.search.query_analysis._get_available_modalities",
|
||||||
|
return_value={},
|
||||||
|
):
|
||||||
|
from memory.api.search import query_analysis
|
||||||
|
query_analysis._analysis_cache = {}
|
||||||
|
|
||||||
|
result = await analyze_query("test query")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAvailableModalities:
    """Tests for the caching behaviour of ``_get_available_modalities``.

    Every test installs its cache state with ``patch.multiple`` so the
    module-level ``_modality_cache`` / ``_cache_timestamp`` values are put
    back when the test finishes and cannot leak into other tests.
    ``_refresh_modality_cache`` is mocked in every test, so the tests only
    observe whether a refresh was requested.
    """

    def test_uses_cache_when_fresh(self):
        """Returns cached data, without refreshing, when the cache is fresh."""
        import time

        from memory.api.search import query_analysis

        # Populated cache stamped "now".
        cached = {"test": ModalityInfo(name="test", count=100)}

        with patch.multiple(
            query_analysis,
            _modality_cache=cached,
            _cache_timestamp=time.time(),
        ):
            with patch(
                "memory.api.search.query_analysis._refresh_modality_cache"
            ) as mock_refresh:
                result = _get_available_modalities()

        mock_refresh.assert_not_called()
        assert "test" in result

    def test_refreshes_cache_when_stale(self):
        """Refreshes a populated cache once its TTL has expired."""
        import time

        from memory.api.search import query_analysis

        # The cache is populated on purpose: an empty cache is refreshed even
        # with a fresh timestamp (see test_refreshes_cache_when_empty), so
        # only a non-empty cache actually exercises the TTL check.
        cached = {"test": ModalityInfo(name="test", count=100)}
        # NOTE(review): assumes the TTL is shorter than two hours - confirm.
        two_hours_ago = time.time() - 7200

        with patch.multiple(
            query_analysis,
            _modality_cache=cached,
            _cache_timestamp=two_hours_ago,
        ):
            with patch(
                "memory.api.search.query_analysis._refresh_modality_cache"
            ) as mock_refresh:
                _get_available_modalities()

        mock_refresh.assert_called_once()

    def test_refreshes_cache_when_empty(self):
        """Refreshes an empty cache even though its timestamp is fresh."""
        import time

        from memory.api.search import query_analysis

        with patch.multiple(
            query_analysis,
            _modality_cache={},
            _cache_timestamp=time.time(),
        ):
            with patch(
                "memory.api.search.query_analysis._refresh_modality_cache"
            ) as mock_refresh:
                _get_available_modalities()

        mock_refresh.assert_called_once()
|
||||||
@ -15,8 +15,9 @@ from memory.api.search.search import (
|
|||||||
apply_title_boost,
|
apply_title_boost,
|
||||||
apply_popularity_boost,
|
apply_popularity_boost,
|
||||||
apply_source_boosts,
|
apply_source_boosts,
|
||||||
expand_query,
|
|
||||||
fuse_scores_rrf,
|
fuse_scores_rrf,
|
||||||
|
)
|
||||||
|
from memory.api.search.constants import (
|
||||||
STOPWORDS,
|
STOPWORDS,
|
||||||
QUERY_TERM_BOOST,
|
QUERY_TERM_BOOST,
|
||||||
TITLE_MATCH_BOOST,
|
TITLE_MATCH_BOOST,
|
||||||
@ -591,78 +592,3 @@ def test_recency_boost_ordering(mock_make_session):
|
|||||||
|
|
||||||
# Newer content should have higher score
|
# Newer content should have higher score
|
||||||
assert chunks[0].relevance_score > chunks[1].relevance_score
|
assert chunks[0].relevance_score > chunks[1].relevance_score
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# expand_query tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
    "query,expected_expansion",
    [
        # AI/ML abbreviations
        # "ML" is expected to expand to "machine learning", the same mapping
        # test_expand_query_multiple_terms relies on.
        ("ML algorithms", "machine learning"),
        ("AI safety research", "artificial intelligence"),
        ("NLP models for text", "natural language processing"),
        # NOTE(review): the expected phrase already occurs in this query, so
        # this case passes whether or not "DL" is expanded.
        ("deep learning vs DL", "deep learning"),
        # Rationality/EA terms
        ("EA organizations", "effective altruism"),
        ("AGI timeline predictions", "artificial general intelligence"),
        ("x-risk reduction", "existential risk"),
        # Reverse mappings
        ("machine learning basics", "ml"),
        ("artificial intelligence ethics", "ai"),
    ],
)
def test_expand_query_adds_synonyms(query, expected_expansion):
    """Should expand abbreviations and add synonyms.

    For each case the expanded query must contain ``expected_expansion``
    (compared case-insensitively) and must still contain the original
    ``query`` text verbatim.
    """
    result = expand_query(query)
    assert expected_expansion in result.lower()
    assert query in result  # Original query preserved
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
    "query",
    [
        "hello world",  # nothing to expand
        "python programming",  # nothing to expand
        "database optimization",  # nothing to expand
    ],
)
def test_expand_query_no_match(query):
    """A query with no expandable term is returned unchanged."""
    assert expand_query(query) == query
|
|
||||||
|
|
||||||
|
|
||||||
def test_expand_query_case_insensitive():
    """The same expansion is added whatever the casing of the term."""
    for variant in ("AI research", "ai research", "Ai Research"):
        expanded = expand_query(variant)
        assert "artificial intelligence" in expanded.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_expand_query_word_boundaries():
    """Only whole words trigger an expansion, never substrings."""
    # "email" and "claim" both contain the letters "ai", but only as part of
    # a longer word, so each query must come back untouched.
    for embedded in ("email server", "insurance claim"):
        assert expand_query(embedded) == embedded
|
|
||||||
|
|
||||||
|
|
||||||
def test_expand_query_multiple_terms():
    """Every matching term in a query gets its own expansion."""
    lowered = expand_query("AI and ML applications").lower()
    for phrase in ("artificial intelligence", "machine learning"):
        assert phrase in lowered
|
|
||||||
|
|
||||||
|
|
||||||
def test_expand_query_preserves_original():
    """The expanded query begins with the untouched original text."""
    query = "AI safety research"
    expanded = expand_query(query)
    assert expanded.startswith(query)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user