Refactor search: add LLM query analysis, extract constants

- Add query_analysis.py for LLM-based query preprocessing
  - Detects modalities from natural language ("on lesswrong" -> forum)
  - Cleans meta-language ("I remember reading..." -> core query)
  - Generates query variants for better recall
  - Dynamically discovers modalities and domains from database

- Extract constants to constants.py
  - STOPWORDS, RRF_K, boost values, etc.
  - Cleaner separation of configuration from logic

- Refactor search_chunks into focused helper functions
  - _run_llm_analysis: parallel query analysis + HyDE
  - _apply_query_analysis: apply analysis results
  - _build_search_data: construct search data with variants
  - _run_searches: embedding + BM25 with RRF fusion
  - _fetch_chunks: database retrieval with scoring
  - _apply_boosts: title, popularity, recency boosts
  - _apply_reranking: cross-encoder reranking

- Remove redundant regex-based modality detection
- Remove static QUERY_EXPANSIONS (LLM handles this better)
- Add comprehensive tests for query_analysis module

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
mruwnik 2025-12-21 14:43:17 +00:00
parent 60e6e18284
commit 782b56939f
6 changed files with 1088 additions and 363 deletions

View File

@@ -0,0 +1,75 @@
"""
Constants for search functionality.
"""
# Reciprocal Rank Fusion constant (k parameter)
# Higher values reduce the influence of top-ranked documents
# 60 is the standard value from the original RRF paper
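# (standard RRF: fused(doc) is the sum over methods of 1 / (RRF_K + rank_method(doc)))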
RRF_K = 60
# Multiplier for internal search limit before fusion
# We search for more candidates than requested, fuse scores, then return top N
# This helps find results that rank well in one method but not the other
CANDIDATE_MULTIPLIER = 5
# How many candidates to pass to reranker (multiplier of final limit)
# Higher = more accurate but slower and more expensive
RERANK_CANDIDATE_MULTIPLIER = 3
# Bonus for chunks containing query terms (added to RRF score)
QUERY_TERM_BOOST = 0.005
# Bonus when query terms match the source title (stronger signal)
TITLE_MATCH_BOOST = 0.01
# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1)))
# This gives a small boost to popular items without dominating relevance
POPULARITY_BOOST = 0.02
# Recency boost settings
# Maximum bonus for brand new content (additive)
RECENCY_BOOST_MAX = 0.005
# Half-life in days: content loses half its recency boost every N days
RECENCY_HALF_LIFE_DAYS = 90
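# (implied decay: boost = RECENCY_BOOST_MAX * 0.5 ** (age_days / RECENCY_HALF_LIFE_DAYS))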
# Common words to ignore when checking for query term presence
STOPWORDS = frozenset({
# Articles
"a", "an", "the",
# Be verbs
"is", "are", "was", "were", "be", "been", "being",
# Have verbs
"have", "has", "had",
# Do verbs
"do", "does", "did",
# Modal verbs
"will", "would", "could", "should", "may", "might", "must", "shall", "can",
"need", "dare", "ought", "used",
# Prepositions
"to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
"through", "during", "before", "after", "above", "below", "between", "under",
# Adverbs
"again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
# Quantifiers
"all", "each", "few", "more", "most", "other", "some", "such",
# Negation
"no", "nor", "not",
# Other common words
"only", "own", "same", "so", "than", "too", "very", "just",
# Conjunctions
"and", "but", "if", "or", "because", "until", "while", "although", "though",
# Relative pronouns
"what", "which", "who", "whom",
# Demonstratives
"this", "that", "these", "those",
# Personal pronouns
"i", "me", "my", "myself",
"we", "our", "ours", "ourselves",
"you", "your", "yours", "yourself", "yourselves",
"he", "him", "his", "himself",
"she", "her", "hers", "herself",
"it", "its", "itself",
"they", "them", "their", "theirs", "themselves",
# Misc common words
"about", "get", "got", "getting", "like", "also",
})
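
Taken together, the boost constants compose roughly as follows: a minimal sketch assuming the additive/multiplicative order described in the comments above (the production logic lives in apply_query_term_boost and apply_source_boosts in search.py; the function name here is hypothetical):

def example_boosted_score(
    rrf_score: float,
    has_query_term: bool,
    title_matches: bool,
    popularity: float,
    age_days: float,
) -> float:
    # Illustration only: composes the constants above in the documented order.
    score = rrf_score
    if has_query_term:
        score += QUERY_TERM_BOOST  # additive bonus for query terms in the chunk
    if title_matches:
        score += TITLE_MATCH_BOOST  # stronger additive bonus for title matches
    score *= 1 + POPULARITY_BOOST * (popularity - 1)  # formula from the comment above
    score += RECENCY_BOOST_MAX * 0.5 ** (age_days / RECENCY_HALF_LIFE_DAYS)  # half-life decay
    return score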

View File

@@ -0,0 +1,327 @@
"""
LLM-based query analysis for intelligent search preprocessing.
Uses a fast LLM (Haiku) to analyze natural language queries and extract:
- Modalities: content types to search (forum, book, comic, etc.)
- Source hints: author names, domains, or specific sources
- Cleaned query: the actual search terms with meta-language removed
- Query variants: alternative phrasings to search
This runs in parallel with HyDE for maximum efficiency.
"""
import asyncio
import json
import logging
import textwrap
import time
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import urlparse
from sqlalchemy import func, text
from sqlalchemy.exc import SQLAlchemyError
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.llms import create_provider, LLMSettings, Message
logger = logging.getLogger(__name__)
# Threshold for listing specific sources (list a modality's domains only when it has at most this many distinct domains)
MAX_DOMAINS_TO_LIST = 10
@dataclass
class ModalityInfo:
"""Information about a modality from the database."""
name: str
count: int
domains: list[str] = field(default_factory=list)
source_count: int | None = None # For modalities with parent entities (e.g., books)
@property
def description(self) -> str:
"""Build description including domains if there are few enough."""
if self.source_count:
base = f"{self.source_count:,} sources ({self.count:,} sections)"
else:
base = f"{self.count:,} items"
if self.domains:
return f"{base} from: {', '.join(self.domains)}"
return base
# Cache for database-derived information
_modality_cache: dict[str, ModalityInfo] = {}
_cache_timestamp: float = 0
_CACHE_TTL_SECONDS = 3600 # Refresh every hour
def _get_tables_with_url_column(db) -> list[str]:
"""Query database schema to find tables that have a 'url' column."""
result = db.execute(text("""
SELECT table_name FROM information_schema.columns
WHERE column_name = 'url' AND table_schema = 'public'
"""))
return [row[0] for row in result]
def _get_modality_domains(db) -> dict[str, list[str]]:
"""Get domains for each modality that has URL data."""
tables = _get_tables_with_url_column(db)
if not tables:
return {}
# Build a UNION query to get modality + url from all URL-containing tables
union_parts = []
for table in tables:
union_parts.append(
f"SELECT s.modality, t.url FROM source_item s "
f"JOIN {table} t ON s.id = t.id WHERE t.url IS NOT NULL"
)
if not union_parts:
return {}
query = " UNION ALL ".join(union_parts)
try:
result = db.execute(text(query))
rows = list(result)
except SQLAlchemyError as e:
logger.debug(f"Database error getting modality URLs: {e}")
return {}
# Group URLs by modality and extract domains
modality_domains: dict[str, set[str]] = {}
for modality, url in rows:
if not modality or not url:
continue
try:
domain = urlparse(url).netloc
if domain:
domain = domain.replace("www.", "")
if modality not in modality_domains:
modality_domains[modality] = set()
modality_domains[modality].add(domain)
except ValueError:
continue
# Only return domains for modalities with few enough to list
return {
modality: sorted(domains)
for modality, domains in modality_domains.items()
if len(domains) <= MAX_DOMAINS_TO_LIST
}
def _get_source_counts(db) -> dict[str, int]:
"""Get distinct source counts for modalities with parent entities."""
try:
# Books: count distinct book_id from book_section
result = db.execute(text(
"SELECT COUNT(DISTINCT book_id) FROM book_section"
))
book_count = result.scalar() or 0
return {"book": book_count}
except SQLAlchemyError:
return {}
def _refresh_modality_cache() -> None:
"""Query database to find modalities with actual content."""
global _modality_cache, _cache_timestamp
try:
with make_session() as db:
# Get modality counts
results = (
db.query(SourceItem.modality, func.count(SourceItem.id))
.group_by(SourceItem.modality)
.order_by(func.count(SourceItem.id).desc())
.all()
)
# Get domains for modalities with URLs (single query)
modality_domains = _get_modality_domains(db)
# Get source counts for modalities with parent entities
source_counts = _get_source_counts(db)
_modality_cache = {}
for modality, count in results:
if modality and count > 0:
_modality_cache[modality] = ModalityInfo(
name=modality,
count=count,
domains=modality_domains.get(modality, []),
source_count=source_counts.get(modality),
)
_cache_timestamp = time.time()
logger.debug(f"Refreshed modality cache: {list(_modality_cache.keys())}")
except SQLAlchemyError as e:
logger.warning(f"Database error refreshing modality cache: {e}")
def _get_available_modalities() -> dict[str, ModalityInfo]:
"""Get modalities with content, refreshing cache if needed."""
global _cache_timestamp
if time.time() - _cache_timestamp > _CACHE_TTL_SECONDS or not _modality_cache:
_refresh_modality_cache()
return _modality_cache
def _build_prompt() -> str:
"""Build the query analysis prompt with actual available modalities."""
modalities = _get_available_modalities()
if not modalities:
modality_section = " (no content indexed yet)"
else:
lines = []
for info in modalities.values():
lines.append(f" - {info.name}: {info.description}")
modality_section = "\n".join(lines)
modality_names = list(modalities.keys()) if modalities else []
return (
textwrap.dedent("""
Analyze this search query and extract structured information.
The user is searching a personal knowledge base containing:
{modality_section}
Return a JSON object:
{{
"modalities": [], // From: {modality_names} (empty = search all)
"sources": [], // Specific sources/authors mentioned
"cleaned_query": "", // Query with meta-language removed
"query_variants": [] // 1-3 alternative phrasings
}}
Guidelines:
- "on lesswrong" -> forum, "comic about" -> comic, etc.
- Remove "there was something about", "I remember reading", etc.
- Generate useful query variants
Return ONLY valid JSON.
""")
.strip()
.format(
modality_section=modality_section,
modality_names=modality_names,
)
)
@dataclass
class QueryAnalysis:
"""Result of LLM-based query analysis."""
modalities: set[str] = field(default_factory=set)
sources: list[str] = field(default_factory=list)
cleaned_query: str = ""
query_variants: list[str] = field(default_factory=list)
success: bool = False
# Cache for recent analyses
_analysis_cache: dict[str, QueryAnalysis] = {}
_CACHE_MAX_SIZE = 100
async def analyze_query(
query: str,
model: Optional[str] = None,
timeout: float = 3.0,
) -> QueryAnalysis:
"""
Analyze a search query using an LLM to extract search parameters.
Args:
query: The user's natural language search query
model: LLM model to use (defaults to SUMMARIZER_MODEL, ideally Haiku)
timeout: Maximum time to wait for LLM response
Returns:
QueryAnalysis with extracted modalities, sources, cleaned query, and variants
"""
# Check cache first
cache_key = query.lower().strip()
if cache_key in _analysis_cache:
logger.debug(f"Query analysis cache hit for: {query[:50]}...")
return _analysis_cache[cache_key]
result = QueryAnalysis(cleaned_query=query)
try:
provider = create_provider(model=model or settings.SUMMARIZER_MODEL)
messages = [Message.user(text=f"Query: {query}")]
llm_settings = LLMSettings(
temperature=0.1, # Low temperature for consistent structured output
max_tokens=300,
)
response = await asyncio.wait_for(
provider.agenerate(
messages=messages,
system_prompt=_build_prompt(),
settings=llm_settings,
),
timeout=timeout,
)
if response:
# Parse JSON response
response = response.strip()
# Handle markdown code blocks
if response.startswith("```"):
response = response.split("```")[1]
if response.startswith("json"):
response = response[4:]
response = response.strip()
try:
data = json.loads(response)
result.modalities = set(data.get("modalities", []))
result.sources = data.get("sources", [])
result.cleaned_query = data.get("cleaned_query", query)
result.query_variants = data.get("query_variants", [])
result.success = True
logger.debug(
f"Query analysis: '{query[:40]}...' -> "
f"modalities={result.modalities}, "
f"cleaned='{result.cleaned_query[:30]}...'"
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse query analysis JSON: {e}")
result.cleaned_query = query
# Cache the result
if len(_analysis_cache) >= _CACHE_MAX_SIZE:
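# Evict the oldest half of cached entries (dicts preserve insertion order)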
keys_to_remove = list(_analysis_cache.keys())[: _CACHE_MAX_SIZE // 2]
for key in keys_to_remove:
del _analysis_cache[key]
_analysis_cache[cache_key] = result
except asyncio.TimeoutError:
logger.warning(f"Query analysis timed out for: {query[:50]}...")
except Exception as e:
logger.error(f"Query analysis failed: {e}")
return result
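
A quick usage sketch (the query and printed values are illustrative, not fixed behavior):

import asyncio
from memory.api.search.query_analysis import analyze_query

async def demo() -> None:
    result = await analyze_query("I remember reading something on lesswrong about goodhart's law")
    if result.success:
        print(result.modalities)      # e.g. {"forum"}
        print(result.cleaned_query)   # e.g. "goodhart's law"
        print(result.query_variants)  # e.g. ["goodharting a metric", ...]

asyncio.run(demo())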

View File

@@ -5,10 +5,8 @@ Search endpoints for the knowledge base API.
import asyncio
import logging
import math
import re
from collections import defaultdict
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.orm import load_only
from memory.common import extract, settings
from memory.common.db.connection import make_session
@@ -16,6 +14,17 @@ from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
from memory.api.search import scorer
from memory.api.search.constants import (
RRF_K,
CANDIDATE_MULTIPLIER,
RERANK_CANDIDATE_MULTIPLIER,
QUERY_TERM_BOOST,
TITLE_MATCH_BOOST,
POPULARITY_BOOST,
RECENCY_BOOST_MAX,
RECENCY_HALF_LIFE_DAYS,
STOPWORDS,
)
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25_chunks
@@ -26,6 +35,7 @@ if settings.ENABLE_HYDE_EXPANSION:
if settings.ENABLE_RERANKING:
from memory.api.search.rerank import rerank_chunks
from memory.api.search.query_analysis import analyze_query, QueryAnalysis
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
# Default config for when none is provided
@@ -33,202 +43,6 @@ _DEFAULT_CONFIG = SearchConfig()
logger = logging.getLogger(__name__)
# Reciprocal Rank Fusion constant (k parameter)
# Higher values reduce the influence of top-ranked documents
# 60 is the standard value from the original RRF paper
RRF_K = 60
# Multiplier for internal search limit before fusion
# We search for more candidates than requested, fuse scores, then return top N
# This helps find results that rank well in one method but not the other
CANDIDATE_MULTIPLIER = 5
# How many candidates to pass to reranker (multiplier of final limit)
# Higher = more accurate but slower and more expensive
RERANK_CANDIDATE_MULTIPLIER = 3
# Bonus for chunks containing query terms (added to RRF score)
QUERY_TERM_BOOST = 0.005
# Bonus when query terms match the source title (stronger signal)
TITLE_MATCH_BOOST = 0.01
# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1)))
# This gives a small boost to popular items without dominating relevance
POPULARITY_BOOST = 0.02
# Recency boost settings
# Maximum bonus for brand new content (additive)
RECENCY_BOOST_MAX = 0.005
# Half-life in days: content loses half its recency boost every N days
RECENCY_HALF_LIFE_DAYS = 90
# Query expansion: map abbreviations/acronyms to full forms
# These help match when users search for "ML" but documents say "machine learning"
QUERY_EXPANSIONS: dict[str, list[str]] = {
# AI/ML abbreviations
"ai": ["artificial intelligence"],
"ml": ["machine learning"],
"dl": ["deep learning"],
"nlp": ["natural language processing"],
"cv": ["computer vision"],
"rl": ["reinforcement learning"],
"llm": ["large language model", "language model"],
"gpt": ["generative pretrained transformer", "language model"],
"nn": ["neural network"],
"cnn": ["convolutional neural network"],
"rnn": ["recurrent neural network"],
"lstm": ["long short term memory"],
"gan": ["generative adversarial network"],
"rag": ["retrieval augmented generation"],
# Rationality/EA terms
"ea": ["effective altruism"],
"lw": ["lesswrong", "less wrong"],
"gwwc": ["giving what we can"],
"agi": ["artificial general intelligence"],
"asi": ["artificial superintelligence"],
"fai": ["friendly ai", "ai alignment"],
"x-risk": ["existential risk"],
"xrisk": ["existential risk"],
"p(doom)": ["probability of doom", "ai risk"],
# Reverse mappings (full forms -> abbreviations)
"artificial intelligence": ["ai"],
"machine learning": ["ml"],
"deep learning": ["dl"],
"natural language processing": ["nlp"],
"computer vision": ["cv"],
"reinforcement learning": ["rl"],
"neural network": ["nn"],
"effective altruism": ["ea"],
"existential risk": ["x-risk", "xrisk"],
# Family relationships (bidirectional)
"father": ["son", "daughter", "child", "parent", "dad"],
"mother": ["son", "daughter", "child", "parent", "mom"],
"parent": ["child", "son", "daughter", "father", "mother"],
"son": ["father", "parent", "child"],
"daughter": ["mother", "parent", "child"],
"child": ["parent", "father", "mother"],
"dad": ["father", "son", "daughter", "child"],
"mom": ["mother", "son", "daughter", "child"],
}
# Modality detection patterns: map query phrases to collection names
# Each entry is (pattern, modalities, strip_pattern)
# - pattern: regex to match in query
# - modalities: set of collection names to filter to
# - strip_pattern: whether to remove the matched text from query
MODALITY_PATTERNS: list[tuple[str, set[str], bool]] = [
# Comics
(r"\b(comic|comics|webcomic|webcomics)\b", {"comic"}, True),
# Forum posts (LessWrong, EA Forum, etc.)
(r"\b(on\s+)?(lesswrong|lw|less\s+wrong)\b", {"forum"}, True),
(r"\b(on\s+)?(ea\s+forum|effective\s+altruism\s+forum)\b", {"forum"}, True),
(r"\b(on\s+)?(alignment\s+forum|af)\b", {"forum"}, True),
(r"\b(forum\s+post|lw\s+post|post\s+on)\b", {"forum"}, True),
# Books
(r"\b(in\s+a\s+book|in\s+the\s+book|book|chapter)\b", {"book"}, True),
# Blog posts / articles
(r"\b(blog\s+post|blog|article)\b", {"blog"}, True),
# Email
(r"\b(email|e-mail|mail)\b", {"mail"}, True),
# Photos / images
(r"\b(photo|photograph|picture|image)\b", {"photo"}, True),
# Documents
(r"\b(document|pdf|doc)\b", {"doc"}, True),
# Chat / messages
(r"\b(chat|message|discord|slack)\b", {"chat"}, True),
# Git
(r"\b(commit|git|pull\s+request|pr)\b", {"git"}, True),
]
# Meta-language patterns to strip (these don't indicate modality, just noise)
META_LANGUAGE_PATTERNS: list[str] = [
r"\bthere\s+was\s+(something|some|some\s+\w+|an?\s+\w+)\s+(about|on)\b",
r"\bi\s+remember\s+(reading|seeing|there\s+being)\s*(an?\s+)?",
r"\bi\s+(read|saw|found)\s+(something|an?\s+\w+)\s+about\b",
r"\bsomething\s+about\b",
r"\bsome\s+about\b",
r"\bthis\s+whole\s+\w+\s+thing\b",
r"\bthat\s+\w+\s+thing\b",
r"\bthat\s+about\b", # Clean up leftover "that about"
r"\ba\s+about\b", # Clean up leftover "a about"
r"\bthe\s+about\b", # Clean up leftover "the about"
r"\bthere\s+was\s+some\s+about\b", # Clean up leftover
]
def detect_modality_hints(query: str) -> tuple[str, set[str]]:
"""
Detect content type hints in query and extract modalities.
Returns:
(cleaned_query, detected_modalities)
- cleaned_query: query with modality indicators and meta-language removed
- detected_modalities: set of collection names detected from query
"""
query_lower = query.lower()
detected: set[str] = set()
cleaned = query
# First, detect and strip modality patterns
for pattern, modalities, strip in MODALITY_PATTERNS:
if re.search(pattern, query_lower, re.IGNORECASE):
detected.update(modalities)
if strip:
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
# Strip meta-language patterns (regardless of modality detection)
for pattern in META_LANGUAGE_PATTERNS:
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
# Clean up whitespace
cleaned = " ".join(cleaned.split())
return cleaned, detected
def expand_query(query: str) -> str:
"""
Expand query with synonyms and abbreviations.
This helps match documents that use different terminology for the same concept.
For example, "ML algorithms" -> "ML machine learning algorithms"
"""
query_lower = query.lower()
expansions = []
for term, synonyms in QUERY_EXPANSIONS.items():
# Check if term appears as a whole word in the query
# Use word boundaries to avoid matching partial words
pattern = r'\b' + re.escape(term) + r'\b'
if re.search(pattern, query_lower):
expansions.extend(synonyms)
if expansions:
# Add expansions to the original query
return query + " " + " ".join(expansions)
return query
# Common words to ignore when checking for query term presence
STOPWORDS = {
"a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
"have", "has", "had", "do", "does", "did", "will", "would", "could",
"should", "may", "might", "must", "shall", "can", "need", "dare",
"ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
"from", "as", "into", "through", "during", "before", "after", "above",
"below", "between", "under", "again", "further", "then", "once", "here",
"there", "when", "where", "why", "how", "all", "each", "few", "more",
"most", "other", "some", "such", "no", "nor", "not", "only", "own",
"same", "so", "than", "too", "very", "just", "and", "but", "if", "or",
"because", "until", "while", "although", "though", "after", "before",
"what", "which", "who", "whom", "this", "that", "these", "those", "i",
"me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your",
"yours", "yourself", "yourselves", "he", "him", "his", "himself", "she",
"her", "hers", "herself", "it", "its", "itself", "they", "them", "their",
"theirs", "themselves", "about", "get", "got", "getting", "like", "also",
}
def extract_query_terms(query: str) -> set[str]:
"""Extract meaningful terms from query, filtering stopwords."""
@@ -270,7 +84,9 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
source_id = chunk.source_id
if source_id not in best_by_source:
best_by_source[source_id] = chunk
elif (chunk.relevance_score or 0) > (
best_by_source[source_id].relevance_score or 0
):
best_by_source[source_id] = chunk
return list(best_by_source.values())
@@ -294,9 +110,7 @@ def apply_source_boosts(
# Single query to fetch all source metadata
with make_session() as db:
sources = db.query(SourceItem).filter(SourceItem.id.in_(source_ids)).all()
source_map = {
s.id: {
"title": (getattr(s, "title", None) or "").lower(),
@@ -369,7 +183,9 @@ def fuse_scores_rrf(
Dict mapping chunk IDs to RRF scores
"""
# Convert scores to ranks (1-indexed)
emb_ranked = sorted(
embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True
)
bm25_ranked = sorted(bm25_scores.keys(), key=lambda x: bm25_scores[x], reverse=True)
emb_ranks = {chunk_id: rank + 1 for rank, chunk_id in enumerate(emb_ranked)}
@@ -393,84 +209,131 @@
return fused
async def _run_llm_analysis(
query_text: str,
use_query_analysis: bool,
use_hyde: bool,
) -> tuple[QueryAnalysis | None, str | None]:
"""
Run LLM-based query analysis and/or HyDE expansion in parallel.
Returns:
(analysis_result, hyde_doc) tuple
"""
analysis_result: QueryAnalysis | None = None
hyde_doc: str | None = None
if not (use_query_analysis or use_hyde):
return analysis_result, hyde_doc
tasks = []
if use_query_analysis:
tasks.append(("analysis", analyze_query(query_text, timeout=3.0)))
if use_hyde and len(query_text.split()) >= 4:
tasks.append(
("hyde", expand_query_hyde(query_text, timeout=settings.HYDE_TIMEOUT))
)
if not tasks:
return analysis_result, hyde_doc
try:
results = await asyncio.gather(
*[task for _, task in tasks], return_exceptions=True
)
for i, (name, _) in enumerate(tasks):
result = results[i]
if isinstance(result, Exception):
logger.warning(f"{name} failed: {result}")
continue
if name == "analysis" and result:
analysis_result = result
elif name == "hyde" and result:
hyde_doc = result
except Exception as e:
logger.warning(f"Parallel LLM calls failed: {e}")
return analysis_result, hyde_doc
def _apply_query_analysis(
analysis_result: QueryAnalysis,
query_text: str,
data: list[extract.DataChunk],
modalities: set[str],
) -> tuple[str, list[extract.DataChunk], set[str], list[str]]:
"""
Apply query analysis results to modify query, data, and modalities.
Returns:
(updated_query_text, updated_data, updated_modalities, query_variants)
"""
query_variants: list[str] = []
if not (analysis_result and analysis_result.success):
return query_text, data, modalities, query_variants
# Use detected modalities if any
if analysis_result.modalities:
modalities = analysis_result.modalities
logger.debug(f"Query analysis modalities: {modalities}")
# Use cleaned query
if analysis_result.cleaned_query and analysis_result.cleaned_query != query_text:
logger.debug(
f"Query analysis cleaning: '{query_text[:40]}...' -> '{analysis_result.cleaned_query[:40]}...'"
)
query_text = analysis_result.cleaned_query
data = [extract.DataChunk(data=[analysis_result.cleaned_query])]
# Collect query variants
query_variants.extend(analysis_result.query_variants)
return query_text, data, modalities, query_variants
def _build_search_data(
data: list[extract.DataChunk],
hyde_doc: str | None,
query_variants: list[str],
query_text: str,
) -> list[extract.DataChunk]:
"""
Build the list of data chunks to search with.
Includes original query, HyDE expansion, and query variants.
"""
search_data = list(data)
# Add HyDE expansion if we got one
if hyde_doc:
logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'")
search_data.append(extract.DataChunk(data=[hyde_doc]))
# Add query variants from analysis (limit to 3)
for variant in query_variants[:3]:
search_data.append(extract.DataChunk(data=[variant]))
return search_data
async def _run_searches(
search_data: list[extract.DataChunk],
data: list[extract.DataChunk],
modalities: set[str],
internal_limit: int,
filters: SearchFilters,
timeout: int,
use_bm25: bool,
) -> dict[str, float]:
"""
Run embedding and optionally BM25 searches, returning fused scores.
"""
# Run embedding search
embedding_scores = await search_chunks_embeddings(
search_data, modalities, internal_limit, filters, timeout
@@ -487,14 +350,25 @@
logger.warning("BM25 search timed out, using embedding results only")
# Fuse scores from both methods using Reciprocal Rank Fusion
return fuse_scores_rrf(embedding_scores, bm25_scores)
def _fetch_chunks(
fused_scores: dict[str, float],
limit: int,
use_reranking: bool,
) -> list[Chunk]:
"""
Fetch chunk objects from database and set their relevance scores.
"""
if not fused_scores:
return []
# Sort by score and take top results
# If reranking is enabled, fetch more candidates for the reranker to work with
sorted_ids = sorted(
fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True
)
if use_reranking:
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
else:
@@ -522,29 +396,125 @@
db.expunge_all()
return chunks
def _apply_boosts(
chunks: list[Chunk],
data: list[extract.DataChunk],
) -> None:
"""
Apply query term, title, popularity, and recency boosts to chunks.
"""
if not chunks:
return
# Extract query text for boosting
query_text = " ".join(
c for chunk in data for c in chunk.data if isinstance(c, str)
)
# Apply query term presence boost
if query_text.strip():
query_terms = extract_query_terms(query_text)
apply_query_term_boost(chunks, query_terms)
# Apply title + popularity boosts (single DB query)
apply_source_boosts(chunks, query_terms)
else:
# No query terms, just apply popularity boost
apply_source_boosts(chunks, set())
async def _apply_reranking(
chunks: list[Chunk],
query_text: str,
limit: int,
use_reranking: bool,
) -> list[Chunk]:
"""
Apply cross-encoder reranking if enabled.
"""
if not (use_reranking and chunks and query_text.strip()):
return chunks
try:
return await rerank_chunks(
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
)
except Exception as e:
logger.warning(f"Reranking failed, using RRF order: {e}")
return chunks
async def search_chunks(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = {},
timeout: int = 2,
config: SearchConfig = _DEFAULT_CONFIG,
) -> list[Chunk]:
"""
Search chunks using embedding similarity and optionally BM25.
Combines results using weighted score fusion, giving bonus to documents
that match both semantically and lexically.
If HyDE is enabled, also generates a hypothetical document from the query
and includes it in the embedding search for better semantic matching.
Enhancement flags in config override global settings when set:
- useBm25: Enable BM25 lexical search
- useHyde: Enable HyDE query expansion
- useReranking: Enable cross-encoder reranking
- useQueryAnalysis: LLM-based query analysis (extracts modalities, cleans query, generates variants)
"""
# Resolve enhancement flags: config overrides global settings
use_bm25 = (
config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
)
use_hyde = (
config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
)
use_reranking = (
config.useReranking
if config.useReranking is not None
else settings.ENABLE_RERANKING
)
use_query_analysis = (
config.useQueryAnalysis if config.useQueryAnalysis is not None else False
)
internal_limit = limit * CANDIDATE_MULTIPLIER
# Extract query text
query_text = " ".join(c for chunk in data for c in chunk.data if isinstance(c, str))
# Run LLM-based operations in parallel (query analysis + HyDE)
analysis_result, hyde_doc = await _run_llm_analysis(
query_text, use_query_analysis, use_hyde
)
# Apply query analysis results
query_text, data, modalities, query_variants = _apply_query_analysis(
analysis_result, query_text, data, modalities
)
# Build search data with HyDE and variants
search_data = _build_search_data(data, hyde_doc, query_variants, query_text)
# Run searches and fuse scores
fused_scores = await _run_searches(
search_data, data, modalities, internal_limit, filters, timeout, use_bm25
)
# Fetch chunks from database
chunks = _fetch_chunks(fused_scores, limit, use_reranking)
# Apply various boosts
_apply_boosts(chunks, data)
# Apply reranking if enabled
chunks = await _apply_reranking(chunks, query_text, limit, use_reranking)
return chunks
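
For reference, the fusion behind _run_searches is standard Reciprocal Rank Fusion. A self-contained sketch of the formula (fuse_scores_rrf above additionally converts raw score dicts into the 1-indexed ranks used here):

def rrf_fuse(
    emb_ranks: dict[str, int],
    bm25_ranks: dict[str, int],
    k: int = 60,  # RRF_K
) -> dict[str, float]:
    # Each method contributes 1 / (k + rank); chunks found by both methods
    # accumulate both terms, which rewards semantic + lexical agreement.
    fused: dict[str, float] = {}
    for ranks in (emb_ranks, bm25_ranks):
        for chunk_id, rank in ranks.items():
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return fused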

View File

@@ -87,8 +87,7 @@ class SearchConfig(BaseModel):
useBm25: Optional[bool] = None
useHyde: Optional[bool] = None
useReranking: Optional[bool] = None
useQueryExpansion: Optional[bool] = None
useModalityDetection: Optional[bool] = None
useQueryAnalysis: Optional[bool] = None # LLM-based query analysis (Haiku)
def model_post_init(self, __context) -> None:
# Enforce reasonable limits

View File

@@ -0,0 +1,428 @@
"""Tests for query_analysis module."""
import json
import pytest
from unittest.mock import patch, Mock, AsyncMock
from memory.api.search.query_analysis import (
ModalityInfo,
QueryAnalysis,
_build_prompt,
_get_available_modalities,
analyze_query,
MAX_DOMAINS_TO_LIST,
)
class TestModalityInfo:
"""Tests for ModalityInfo dataclass."""
def test_description_items_only(self):
"""Basic description with just count."""
info = ModalityInfo(name="forum", count=1000)
assert info.description == "1,000 items"
def test_description_with_domains(self):
"""Description includes domains when present."""
info = ModalityInfo(
name="forum", count=1000, domains=["lesswrong.com", "example.com"]
)
assert info.description == "1,000 items from: lesswrong.com, example.com"
def test_description_with_source_count(self):
"""Description shows sources and sections for parent entities."""
info = ModalityInfo(name="book", count=22000, source_count=400)
assert info.description == "400 sources (22,000 sections)"
def test_description_with_source_count_and_domains(self):
"""Description shows sources, sections, and domains."""
info = ModalityInfo(
name="comic",
count=8000,
source_count=50,
domains=["xkcd.com", "smbc-comics.com"],
)
assert (
info.description
== "50 sources (8,000 sections) from: xkcd.com, smbc-comics.com"
)
def test_description_formats_large_numbers(self):
"""Numbers are formatted with commas."""
info = ModalityInfo(name="book", count=1234567, source_count=9876)
assert info.description == "9,876 sources (1,234,567 sections)"
class TestQueryAnalysis:
"""Tests for QueryAnalysis dataclass."""
def test_default_values(self):
"""QueryAnalysis has sensible defaults."""
result = QueryAnalysis()
assert result.modalities == set()
assert result.sources == []
assert result.cleaned_query == ""
assert result.query_variants == []
assert result.success is False
def test_with_values(self):
"""QueryAnalysis stores provided values."""
result = QueryAnalysis(
modalities={"forum", "book"},
sources=["lesswrong.com"],
cleaned_query="test query",
query_variants=["alternative query"],
success=True,
)
assert result.modalities == {"forum", "book"}
assert result.sources == ["lesswrong.com"]
assert result.cleaned_query == "test query"
assert result.query_variants == ["alternative query"]
assert result.success is True
class TestBuildPrompt:
"""Tests for _build_prompt function."""
def test_builds_prompt_with_modalities(self):
"""Prompt includes modality information."""
mock_modalities = {
"forum": ModalityInfo(
name="forum", count=20000, domains=["lesswrong.com"]
),
"book": ModalityInfo(name="book", count=22000, source_count=400),
}
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value=mock_modalities,
):
prompt = _build_prompt()
assert "forum" in prompt
assert "book" in prompt
assert "20,000 items from: lesswrong.com" in prompt
assert "400 sources (22,000 sections)" in prompt
assert "modalities" in prompt
assert "cleaned_query" in prompt
def test_builds_prompt_empty_modalities(self):
"""Prompt handles no modalities gracefully."""
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
prompt = _build_prompt()
assert "no content indexed yet" in prompt
def test_prompt_contains_json_structure(self):
"""Prompt includes expected JSON structure."""
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={"test": ModalityInfo(name="test", count=100)},
):
prompt = _build_prompt()
assert '"modalities": []' in prompt
assert '"sources": []' in prompt
assert '"cleaned_query": ""' in prompt
assert '"query_variants": []' in prompt
def test_prompt_contains_guidelines(self):
"""Prompt includes usage guidelines."""
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={"forum": ModalityInfo(name="forum", count=100)},
):
prompt = _build_prompt()
assert "lesswrong" in prompt.lower()
assert "comic" in prompt.lower()
assert "Remove" in prompt
assert "Return ONLY valid JSON" in prompt
class TestAnalyzeQuery:
"""Tests for analyze_query async function."""
@pytest.mark.asyncio
async def test_returns_analysis_on_success(self):
"""Successfully parses LLM JSON response."""
mock_response = json.dumps(
{
"modalities": ["forum"],
"sources": ["lesswrong.com"],
"cleaned_query": "rationality concepts",
"query_variants": ["rational thinking", "epistemic rationality"],
}
)
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value=mock_response)
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={"forum": ModalityInfo(name="forum", count=100)},
):
# Clear any cached results
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("something on lesswrong about rationality")
assert result.success is True
assert result.modalities == {"forum"}
assert result.sources == ["lesswrong.com"]
assert result.cleaned_query == "rationality concepts"
assert "rational thinking" in result.query_variants
@pytest.mark.asyncio
async def test_handles_markdown_code_blocks(self):
"""Strips markdown code blocks from response."""
mock_response = """```json
{
"modalities": ["book"],
"sources": [],
"cleaned_query": "test query",
"query_variants": []
}
```"""
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value=mock_response)
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={"book": ModalityInfo(name="book", count=100)},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("find something in a book")
assert result.success is True
assert result.modalities == {"book"}
@pytest.mark.asyncio
async def test_handles_invalid_json(self):
"""Returns default result on invalid JSON."""
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value="not valid json {{{")
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("test query")
assert result.success is False
assert result.cleaned_query == "test query"
@pytest.mark.asyncio
async def test_handles_timeout(self):
"""Returns default result on timeout."""
import asyncio
async def slow_response(*args, **kwargs):
await asyncio.sleep(10)
return "{}"
mock_provider = Mock()
mock_provider.agenerate = slow_response
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("test query", timeout=0.01)
assert result.success is False
assert result.cleaned_query == "test query"
@pytest.mark.asyncio
async def test_caches_results(self):
"""Caches analysis results for repeated queries."""
mock_response = json.dumps(
{
"modalities": [],
"sources": [],
"cleaned_query": "cached query",
"query_variants": [],
}
)
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value=mock_response)
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
# First call
result1 = await analyze_query("test query for caching")
# Second call (should use cache)
result2 = await analyze_query("test query for caching")
# Provider should only be called once
assert mock_provider.agenerate.call_count == 1
assert result1.cleaned_query == result2.cleaned_query
@pytest.mark.asyncio
async def test_cache_case_insensitive(self):
"""Cache key is case-insensitive."""
mock_response = json.dumps(
{
"modalities": [],
"sources": [],
"cleaned_query": "test",
"query_variants": [],
}
)
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value=mock_response)
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
await analyze_query("Test Query")
await analyze_query("test query")
await analyze_query("TEST QUERY")
# All variations should hit the same cache entry
assert mock_provider.agenerate.call_count == 1
@pytest.mark.asyncio
async def test_handles_empty_response(self):
"""Handles empty LLM response gracefully."""
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value="")
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("test query")
assert result.success is False
assert result.cleaned_query == "test query"
@pytest.mark.asyncio
async def test_handles_none_response(self):
"""Handles None LLM response gracefully."""
mock_provider = Mock()
mock_provider.agenerate = AsyncMock(return_value=None)
with patch(
"memory.api.search.query_analysis.create_provider",
return_value=mock_provider,
):
with patch(
"memory.api.search.query_analysis._get_available_modalities",
return_value={},
):
from memory.api.search import query_analysis
query_analysis._analysis_cache = {}
result = await analyze_query("test query")
assert result.success is False
class TestGetAvailableModalities:
"""Tests for _get_available_modalities function."""
def test_uses_cache_when_fresh(self):
"""Returns cached data when cache is fresh."""
import time
from memory.api.search import query_analysis
# Set up cache
query_analysis._modality_cache = {
"test": ModalityInfo(name="test", count=100)
}
query_analysis._cache_timestamp = time.time()
with patch(
"memory.api.search.query_analysis._refresh_modality_cache"
) as mock_refresh:
result = _get_available_modalities()
mock_refresh.assert_not_called()
assert "test" in result
def test_refreshes_cache_when_stale(self):
"""Refreshes cache when TTL has expired."""
import time
from memory.api.search import query_analysis
# Set up stale cache
query_analysis._modality_cache = {}
query_analysis._cache_timestamp = time.time() - 7200 # 2 hours ago
with patch(
"memory.api.search.query_analysis._refresh_modality_cache"
) as mock_refresh:
_get_available_modalities()
mock_refresh.assert_called_once()
def test_refreshes_cache_when_empty(self):
"""Refreshes cache when cache is empty."""
import time
from memory.api.search import query_analysis
query_analysis._modality_cache = {}
query_analysis._cache_timestamp = time.time()
with patch(
"memory.api.search.query_analysis._refresh_modality_cache"
) as mock_refresh:
_get_available_modalities()
mock_refresh.assert_called_once()

View File

@@ -15,8 +15,9 @@ from memory.api.search.search import (
apply_title_boost,
apply_popularity_boost,
apply_source_boosts,
expand_query,
fuse_scores_rrf,
)
from memory.api.search.constants import (
STOPWORDS,
QUERY_TERM_BOOST,
TITLE_MATCH_BOOST,
@@ -591,78 +592,3 @@ def test_recency_boost_ordering(mock_make_session):
# Newer content should have higher score
assert chunks[0].relevance_score > chunks[1].relevance_score
# ============================================================================
# expand_query tests
# ============================================================================
@pytest.mark.parametrize(
"query,expected_expansion",
[
# AI/ML abbreviations
("ML algorithms", "artificial intelligence"), # Not "ML" -> won't have AI
("AI safety research", "artificial intelligence"),
("NLP models for text", "natural language processing"),
("deep learning vs DL", "deep learning"),
# Rationality/EA terms
("EA organizations", "effective altruism"),
("AGI timeline predictions", "artificial general intelligence"),
("x-risk reduction", "existential risk"),
# Reverse mappings
("machine learning basics", "ml"),
("artificial intelligence ethics", "ai"),
],
)
def test_expand_query_adds_synonyms(query, expected_expansion):
"""Should expand abbreviations and add synonyms."""
result = expand_query(query)
assert expected_expansion in result.lower()
assert query in result # Original query preserved
@pytest.mark.parametrize(
"query",
[
"hello world", # No expansions
"python programming", # No expansions
"database optimization", # No expansions
],
)
def test_expand_query_no_match(query):
"""Should return original query when no expansions match."""
result = expand_query(query)
assert result == query
def test_expand_query_case_insensitive():
"""Should match terms regardless of case."""
assert "artificial intelligence" in expand_query("AI research").lower()
assert "artificial intelligence" in expand_query("ai research").lower()
assert "artificial intelligence" in expand_query("Ai Research").lower()
def test_expand_query_word_boundaries():
"""Should only match whole words, not partial matches."""
# "mail" contains "ai" but shouldn't trigger expansion
result = expand_query("email server")
assert result == "email server"
# "claim" contains "ai" but shouldn't trigger expansion
result = expand_query("insurance claim")
assert result == "insurance claim"
def test_expand_query_multiple_terms():
"""Should expand multiple matching terms."""
result = expand_query("AI and ML applications")
assert "artificial intelligence" in result.lower()
assert "machine learning" in result.lower()
def test_expand_query_preserves_original():
"""Should preserve original query text."""
original = "AI safety research"
result = expand_query(original)
assert result.startswith(original)