From 782b56939fec2382746cfda29c0d30e1b42dd48e Mon Sep 17 00:00:00 2001 From: mruwnik Date: Sun, 21 Dec 2025 14:43:17 +0000 Subject: [PATCH] Refactor search: add LLM query analysis, extract constants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add query_analysis.py for LLM-based query preprocessing - Detects modalities from natural language ("on lesswrong" -> forum) - Cleans meta-language ("I remember reading..." -> core query) - Generates query variants for better recall - Dynamically discovers modalities and domains from database - Extract constants to constants.py - STOPWORDS, RRF_K, boost values, etc. - Cleaner separation of configuration from logic - Refactor search_chunks into focused helper functions - _run_llm_analysis: parallel query analysis + HyDE - _apply_query_analysis: apply analysis results - _build_search_data: construct search data with variants - _run_searches: embedding + BM25 with RRF fusion - _fetch_chunks: database retrieval with scoring - _apply_boosts: title, popularity, recency boosts - _apply_reranking: cross-encoder reranking - Remove redundant regex-based modality detection - Remove static QUERY_EXPANSIONS (LLM handles this better) - Add comprehensive tests for query_analysis module 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/memory/api/search/constants.py | 75 +++ src/memory/api/search/query_analysis.py | 327 +++++++++++ src/memory/api/search/search.py | 540 +++++++++--------- src/memory/api/search/types.py | 3 +- .../memory/api/search/test_query_analysis.py | 428 ++++++++++++++ tests/memory/api/search/test_search.py | 78 +-- 6 files changed, 1088 insertions(+), 363 deletions(-) create mode 100644 src/memory/api/search/constants.py create mode 100644 src/memory/api/search/query_analysis.py create mode 100644 tests/memory/api/search/test_query_analysis.py diff --git a/src/memory/api/search/constants.py b/src/memory/api/search/constants.py new file mode 100644 index 0000000..1de5830 --- /dev/null +++ b/src/memory/api/search/constants.py @@ -0,0 +1,75 @@ +""" +Constants for search functionality. 
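+
+Extracted from search.py so tuning knobs (RRF_K, boost weights, STOPWORDS)
+can be imported by search code and tests without pulling in the search logic.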
+""" + +# Reciprocal Rank Fusion constant (k parameter) +# Higher values reduce the influence of top-ranked documents +# 60 is the standard value from the original RRF paper +RRF_K = 60 + +# Multiplier for internal search limit before fusion +# We search for more candidates than requested, fuse scores, then return top N +# This helps find results that rank well in one method but not the other +CANDIDATE_MULTIPLIER = 5 + +# How many candidates to pass to reranker (multiplier of final limit) +# Higher = more accurate but slower and more expensive +RERANK_CANDIDATE_MULTIPLIER = 3 + +# Bonus for chunks containing query terms (added to RRF score) +QUERY_TERM_BOOST = 0.005 + +# Bonus when query terms match the source title (stronger signal) +TITLE_MATCH_BOOST = 0.01 + +# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1))) +# This gives a small boost to popular items without dominating relevance +POPULARITY_BOOST = 0.02 + +# Recency boost settings +# Maximum bonus for brand new content (additive) +RECENCY_BOOST_MAX = 0.005 +# Half-life in days: content loses half its recency boost every N days +RECENCY_HALF_LIFE_DAYS = 90 + +# Common words to ignore when checking for query term presence +STOPWORDS = frozenset({ + # Articles + "a", "an", "the", + # Be verbs + "is", "are", "was", "were", "be", "been", "being", + # Have verbs + "have", "has", "had", + # Do verbs + "do", "does", "did", + # Modal verbs + "will", "would", "could", "should", "may", "might", "must", "shall", "can", + "need", "dare", "ought", "used", + # Prepositions + "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", + "through", "during", "before", "after", "above", "below", "between", "under", + # Adverbs + "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", + # Quantifiers + "all", "each", "few", "more", "most", "other", "some", "such", + # Negation + "no", "nor", "not", + # Other common words + "only", "own", "same", "so", "than", "too", "very", "just", + # Conjunctions + "and", "but", "if", "or", "because", "until", "while", "although", "though", + # Relative pronouns + "what", "which", "who", "whom", + # Demonstratives + "this", "that", "these", "those", + # Personal pronouns + "i", "me", "my", "myself", + "we", "our", "ours", "ourselves", + "you", "your", "yours", "yourself", "yourselves", + "he", "him", "his", "himself", + "she", "her", "hers", "herself", + "it", "its", "itself", + "they", "them", "their", "theirs", "themselves", + # Misc common words + "about", "get", "got", "getting", "like", "also", +}) diff --git a/src/memory/api/search/query_analysis.py b/src/memory/api/search/query_analysis.py new file mode 100644 index 0000000..4e64424 --- /dev/null +++ b/src/memory/api/search/query_analysis.py @@ -0,0 +1,327 @@ +""" +LLM-based query analysis for intelligent search preprocessing. + +Uses a fast LLM (Haiku) to analyze natural language queries and extract: +- Modalities: content types to search (forum, book, comic, etc.) +- Source hints: author names, domains, or specific sources +- Cleaned query: the actual search terms with meta-language removed +- Query variants: alternative phrasings to search + +This runs in parallel with HyDE for maximum efficiency. 
+""" + +import asyncio +import json +import logging +import textwrap +import time +from dataclasses import dataclass, field +from typing import Optional +from urllib.parse import urlparse + +from sqlalchemy import func, text +from sqlalchemy.exc import SQLAlchemyError + +from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models import SourceItem +from memory.common.llms import create_provider, LLMSettings, Message + +logger = logging.getLogger(__name__) + +# Threshold for listing specific sources (if fewer than N distinct domains, list them) +MAX_DOMAINS_TO_LIST = 10 + + + +@dataclass +class ModalityInfo: + """Information about a modality from the database.""" + + name: str + count: int + domains: list[str] = field(default_factory=list) + source_count: int | None = None # For modalities with parent entities (e.g., books) + + @property + def description(self) -> str: + """Build description including domains if there are few enough.""" + if self.source_count: + base = f"{self.source_count:,} sources ({self.count:,} sections)" + else: + base = f"{self.count:,} items" + + if self.domains: + return f"{base} from: {', '.join(self.domains)}" + return base + + +# Cache for database-derived information +_modality_cache: dict[str, ModalityInfo] = {} +_cache_timestamp: float = 0 +_CACHE_TTL_SECONDS = 3600 # Refresh every hour + + +def _get_tables_with_url_column(db) -> list[str]: + """Query database schema to find tables that have a 'url' column.""" + result = db.execute(text(""" + SELECT table_name FROM information_schema.columns + WHERE column_name = 'url' AND table_schema = 'public' + """)) + return [row[0] for row in result] + + +def _get_modality_domains(db) -> dict[str, list[str]]: + """Get domains for each modality that has URL data.""" + tables = _get_tables_with_url_column(db) + if not tables: + return {} + + # Build a UNION query to get modality + url from all URL-containing tables + union_parts = [] + for table in tables: + union_parts.append( + f"SELECT s.modality, t.url FROM source_item s " + f"JOIN {table} t ON s.id = t.id WHERE t.url IS NOT NULL" + ) + + if not union_parts: + return {} + + query = " UNION ALL ".join(union_parts) + + try: + result = db.execute(text(query)) + rows = list(result) + except SQLAlchemyError as e: + logger.debug(f"Database error getting modality URLs: {e}") + return {} + + # Group URLs by modality and extract domains + modality_domains: dict[str, set[str]] = {} + for modality, url in rows: + if not modality or not url: + continue + try: + domain = urlparse(url).netloc + if domain: + domain = domain.replace("www.", "") + if modality not in modality_domains: + modality_domains[modality] = set() + modality_domains[modality].add(domain) + except ValueError: + continue + + # Only return domains for modalities with few enough to list + return { + modality: sorted(domains) + for modality, domains in modality_domains.items() + if len(domains) <= MAX_DOMAINS_TO_LIST + } + + +def _get_source_counts(db) -> dict[str, int]: + """Get distinct source counts for modalities with parent entities.""" + try: + # Books: count distinct book_id from book_section + result = db.execute(text( + "SELECT COUNT(DISTINCT book_id) FROM book_section" + )) + book_count = result.scalar() or 0 + return {"book": book_count} + except SQLAlchemyError: + return {} + + +def _refresh_modality_cache() -> None: + """Query database to find modalities with actual content.""" + global _modality_cache, _cache_timestamp + + try: + with 
make_session() as db: + # Get modality counts + results = ( + db.query(SourceItem.modality, func.count(SourceItem.id)) + .group_by(SourceItem.modality) + .order_by(func.count(SourceItem.id).desc()) + .all() + ) + + # Get domains for modalities with URLs (single query) + modality_domains = _get_modality_domains(db) + + # Get source counts for modalities with parent entities + source_counts = _get_source_counts(db) + + _modality_cache = {} + for modality, count in results: + if modality and count > 0: + _modality_cache[modality] = ModalityInfo( + name=modality, + count=count, + domains=modality_domains.get(modality, []), + source_count=source_counts.get(modality), + ) + + _cache_timestamp = time.time() + logger.debug(f"Refreshed modality cache: {list(_modality_cache.keys())}") + + except SQLAlchemyError as e: + logger.warning(f"Database error refreshing modality cache: {e}") + + +def _get_available_modalities() -> dict[str, ModalityInfo]: + """Get modalities with content, refreshing cache if needed.""" + global _cache_timestamp + + if time.time() - _cache_timestamp > _CACHE_TTL_SECONDS or not _modality_cache: + _refresh_modality_cache() + + return _modality_cache + + +def _build_prompt() -> str: + """Build the query analysis prompt with actual available modalities.""" + modalities = _get_available_modalities() + + if not modalities: + modality_section = " (no content indexed yet)" + else: + lines = [] + for info in modalities.values(): + lines.append(f" - {info.name}: {info.description}") + modality_section = "\n".join(lines) + + modality_names = list(modalities.keys()) if modalities else [] + + return ( + textwrap.dedent(""" + Analyze this search query and extract structured information. + + The user is searching a personal knowledge base containing: + {modality_section} + + Return a JSON object: + {{ + "modalities": [], // From: {modality_names} (empty = search all) + "sources": [], // Specific sources/authors mentioned + "cleaned_query": "", // Query with meta-language removed + "query_variants": [] // 1-3 alternative phrasings + }} + + Guidelines: + - "on lesswrong" -> forum, "comic about" -> comic, etc. + - Remove "there was something about", "I remember reading", etc. + - Generate useful query variants + + Return ONLY valid JSON. + """) + .strip() + .format( + modality_section=modality_section, + modality_names=modality_names, + ) + ) + + +@dataclass +class QueryAnalysis: + """Result of LLM-based query analysis.""" + + modalities: set[str] = field(default_factory=set) + sources: list[str] = field(default_factory=list) + cleaned_query: str = "" + query_variants: list[str] = field(default_factory=list) + success: bool = False + + +# Cache for recent analyses +_analysis_cache: dict[str, QueryAnalysis] = {} +_CACHE_MAX_SIZE = 100 + + +async def analyze_query( + query: str, + model: Optional[str] = None, + timeout: float = 3.0, +) -> QueryAnalysis: + """ + Analyze a search query using an LLM to extract search parameters. 
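+
+    Results are cached in-process (keyed on the lowercased query); timeouts and
+    unparseable responses fall back to a result that echoes the original query
+    with success=False.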
+ + Args: + query: The user's natural language search query + model: LLM model to use (defaults to SUMMARIZER_MODEL, ideally Haiku) + timeout: Maximum time to wait for LLM response + + Returns: + QueryAnalysis with extracted modalities, sources, cleaned query, and variants + """ + # Check cache first + cache_key = query.lower().strip() + if cache_key in _analysis_cache: + logger.debug(f"Query analysis cache hit for: {query[:50]}...") + return _analysis_cache[cache_key] + + result = QueryAnalysis(cleaned_query=query) + + try: + provider = create_provider(model=model or settings.SUMMARIZER_MODEL) + + messages = [Message.user(text=f"Query: {query}")] + + llm_settings = LLMSettings( + temperature=0.1, # Low temperature for consistent structured output + max_tokens=300, + ) + + response = await asyncio.wait_for( + provider.agenerate( + messages=messages, + system_prompt=_build_prompt(), + settings=llm_settings, + ), + timeout=timeout, + ) + + if response: + # Parse JSON response + response = response.strip() + # Handle markdown code blocks + if response.startswith("```"): + response = response.split("```")[1] + if response.startswith("json"): + response = response[4:] + response = response.strip() + + try: + data = json.loads(response) + + result.modalities = set(data.get("modalities", [])) + result.sources = data.get("sources", []) + result.cleaned_query = data.get("cleaned_query", query) + result.query_variants = data.get("query_variants", []) + result.success = True + + logger.debug( + f"Query analysis: '{query[:40]}...' -> " + f"modalities={result.modalities}, " + f"cleaned='{result.cleaned_query[:30]}...'" + ) + + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse query analysis JSON: {e}") + result.cleaned_query = query + + # Cache the result + if len(_analysis_cache) >= _CACHE_MAX_SIZE: + keys_to_remove = list(_analysis_cache.keys())[: _CACHE_MAX_SIZE // 2] + for key in keys_to_remove: + del _analysis_cache[key] + _analysis_cache[cache_key] = result + + except asyncio.TimeoutError: + logger.warning(f"Query analysis timed out for: {query[:50]}...") + except Exception as e: + logger.error(f"Query analysis failed: {e}") + + return result diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py index 1476c3b..b566e29 100644 --- a/src/memory/api/search/search.py +++ b/src/memory/api/search/search.py @@ -5,10 +5,8 @@ Search endpoints for the knowledge base API. 
import asyncio import logging import math -import re from collections import defaultdict from datetime import datetime, timezone -from typing import Optional from sqlalchemy.orm import load_only from memory.common import extract, settings from memory.common.db.connection import make_session @@ -16,6 +14,17 @@ from memory.common.db.models import Chunk, SourceItem from memory.common.collections import ALL_COLLECTIONS from memory.api.search.embeddings import search_chunks_embeddings from memory.api.search import scorer +from memory.api.search.constants import ( + RRF_K, + CANDIDATE_MULTIPLIER, + RERANK_CANDIDATE_MULTIPLIER, + QUERY_TERM_BOOST, + TITLE_MATCH_BOOST, + POPULARITY_BOOST, + RECENCY_BOOST_MAX, + RECENCY_HALF_LIFE_DAYS, + STOPWORDS, +) if settings.ENABLE_BM25_SEARCH: from memory.api.search.bm25 import search_bm25_chunks @@ -26,6 +35,7 @@ if settings.ENABLE_HYDE_EXPANSION: if settings.ENABLE_RERANKING: from memory.api.search.rerank import rerank_chunks +from memory.api.search.query_analysis import analyze_query, QueryAnalysis from memory.api.search.types import SearchConfig, SearchFilters, SearchResult # Default config for when none is provided @@ -33,202 +43,6 @@ _DEFAULT_CONFIG = SearchConfig() logger = logging.getLogger(__name__) -# Reciprocal Rank Fusion constant (k parameter) -# Higher values reduce the influence of top-ranked documents -# 60 is the standard value from the original RRF paper -RRF_K = 60 - -# Multiplier for internal search limit before fusion -# We search for more candidates than requested, fuse scores, then return top N -# This helps find results that rank well in one method but not the other -CANDIDATE_MULTIPLIER = 5 - -# How many candidates to pass to reranker (multiplier of final limit) -# Higher = more accurate but slower and more expensive -RERANK_CANDIDATE_MULTIPLIER = 3 - -# Bonus for chunks containing query terms (added to RRF score) -QUERY_TERM_BOOST = 0.005 - -# Bonus when query terms match the source title (stronger signal) -TITLE_MATCH_BOOST = 0.01 - -# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1))) -# This gives a small boost to popular items without dominating relevance -POPULARITY_BOOST = 0.02 - -# Recency boost settings -# Maximum bonus for brand new content (additive) -RECENCY_BOOST_MAX = 0.005 -# Half-life in days: content loses half its recency boost every N days -RECENCY_HALF_LIFE_DAYS = 90 - -# Query expansion: map abbreviations/acronyms to full forms -# These help match when users search for "ML" but documents say "machine learning" -QUERY_EXPANSIONS: dict[str, list[str]] = { - # AI/ML abbreviations - "ai": ["artificial intelligence"], - "ml": ["machine learning"], - "dl": ["deep learning"], - "nlp": ["natural language processing"], - "cv": ["computer vision"], - "rl": ["reinforcement learning"], - "llm": ["large language model", "language model"], - "gpt": ["generative pretrained transformer", "language model"], - "nn": ["neural network"], - "cnn": ["convolutional neural network"], - "rnn": ["recurrent neural network"], - "lstm": ["long short term memory"], - "gan": ["generative adversarial network"], - "rag": ["retrieval augmented generation"], - # Rationality/EA terms - "ea": ["effective altruism"], - "lw": ["lesswrong", "less wrong"], - "gwwc": ["giving what we can"], - "agi": ["artificial general intelligence"], - "asi": ["artificial superintelligence"], - "fai": ["friendly ai", "ai alignment"], - "x-risk": ["existential risk"], - "xrisk": ["existential risk"], - "p(doom)": ["probability 
of doom", "ai risk"], - # Reverse mappings (full forms -> abbreviations) - "artificial intelligence": ["ai"], - "machine learning": ["ml"], - "deep learning": ["dl"], - "natural language processing": ["nlp"], - "computer vision": ["cv"], - "reinforcement learning": ["rl"], - "neural network": ["nn"], - "effective altruism": ["ea"], - "existential risk": ["x-risk", "xrisk"], - # Family relationships (bidirectional) - "father": ["son", "daughter", "child", "parent", "dad"], - "mother": ["son", "daughter", "child", "parent", "mom"], - "parent": ["child", "son", "daughter", "father", "mother"], - "son": ["father", "parent", "child"], - "daughter": ["mother", "parent", "child"], - "child": ["parent", "father", "mother"], - "dad": ["father", "son", "daughter", "child"], - "mom": ["mother", "son", "daughter", "child"], -} - -# Modality detection patterns: map query phrases to collection names -# Each entry is (pattern, modalities, strip_pattern) -# - pattern: regex to match in query -# - modalities: set of collection names to filter to -# - strip_pattern: whether to remove the matched text from query -MODALITY_PATTERNS: list[tuple[str, set[str], bool]] = [ - # Comics - (r"\b(comic|comics|webcomic|webcomics)\b", {"comic"}, True), - # Forum posts (LessWrong, EA Forum, etc.) - (r"\b(on\s+)?(lesswrong|lw|less\s+wrong)\b", {"forum"}, True), - (r"\b(on\s+)?(ea\s+forum|effective\s+altruism\s+forum)\b", {"forum"}, True), - (r"\b(on\s+)?(alignment\s+forum|af)\b", {"forum"}, True), - (r"\b(forum\s+post|lw\s+post|post\s+on)\b", {"forum"}, True), - # Books - (r"\b(in\s+a\s+book|in\s+the\s+book|book|chapter)\b", {"book"}, True), - # Blog posts / articles - (r"\b(blog\s+post|blog|article)\b", {"blog"}, True), - # Email - (r"\b(email|e-mail|mail)\b", {"mail"}, True), - # Photos / images - (r"\b(photo|photograph|picture|image)\b", {"photo"}, True), - # Documents - (r"\b(document|pdf|doc)\b", {"doc"}, True), - # Chat / messages - (r"\b(chat|message|discord|slack)\b", {"chat"}, True), - # Git - (r"\b(commit|git|pull\s+request|pr)\b", {"git"}, True), -] - -# Meta-language patterns to strip (these don't indicate modality, just noise) -META_LANGUAGE_PATTERNS: list[str] = [ - r"\bthere\s+was\s+(something|some|some\s+\w+|an?\s+\w+)\s+(about|on)\b", - r"\bi\s+remember\s+(reading|seeing|there\s+being)\s*(an?\s+)?", - r"\bi\s+(read|saw|found)\s+(something|an?\s+\w+)\s+about\b", - r"\bsomething\s+about\b", - r"\bsome\s+about\b", - r"\bthis\s+whole\s+\w+\s+thing\b", - r"\bthat\s+\w+\s+thing\b", - r"\bthat\s+about\b", # Clean up leftover "that about" - r"\ba\s+about\b", # Clean up leftover "a about" - r"\bthe\s+about\b", # Clean up leftover "the about" - r"\bthere\s+was\s+some\s+about\b", # Clean up leftover -] - - -def detect_modality_hints(query: str) -> tuple[str, set[str]]: - """ - Detect content type hints in query and extract modalities. 
- - Returns: - (cleaned_query, detected_modalities) - - cleaned_query: query with modality indicators and meta-language removed - - detected_modalities: set of collection names detected from query - """ - query_lower = query.lower() - detected: set[str] = set() - cleaned = query - - # First, detect and strip modality patterns - for pattern, modalities, strip in MODALITY_PATTERNS: - if re.search(pattern, query_lower, re.IGNORECASE): - detected.update(modalities) - if strip: - cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE) - - # Strip meta-language patterns (regardless of modality detection) - for pattern in META_LANGUAGE_PATTERNS: - cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE) - - # Clean up whitespace - cleaned = " ".join(cleaned.split()) - - return cleaned, detected - - -def expand_query(query: str) -> str: - """ - Expand query with synonyms and abbreviations. - - This helps match documents that use different terminology for the same concept. - For example, "ML algorithms" -> "ML machine learning algorithms" - """ - query_lower = query.lower() - expansions = [] - - for term, synonyms in QUERY_EXPANSIONS.items(): - # Check if term appears as a whole word in the query - # Use word boundaries to avoid matching partial words - pattern = r'\b' + re.escape(term) + r'\b' - if re.search(pattern, query_lower): - expansions.extend(synonyms) - - if expansions: - # Add expansions to the original query - return query + " " + " ".join(expansions) - return query - - -# Common words to ignore when checking for query term presence -STOPWORDS = { - "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "do", "does", "did", "will", "would", "could", - "should", "may", "might", "must", "shall", "can", "need", "dare", - "ought", "used", "to", "of", "in", "for", "on", "with", "at", "by", - "from", "as", "into", "through", "during", "before", "after", "above", - "below", "between", "under", "again", "further", "then", "once", "here", - "there", "when", "where", "why", "how", "all", "each", "few", "more", - "most", "other", "some", "such", "no", "nor", "not", "only", "own", - "same", "so", "than", "too", "very", "just", "and", "but", "if", "or", - "because", "until", "while", "although", "though", "after", "before", - "what", "which", "who", "whom", "this", "that", "these", "those", "i", - "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", - "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", - "her", "hers", "herself", "it", "its", "itself", "they", "them", "their", - "theirs", "themselves", "about", "get", "got", "getting", "like", "also", -} - def extract_query_terms(query: str) -> set[str]: """Extract meaningful terms from query, filtering stopwords.""" @@ -270,7 +84,9 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]: source_id = chunk.source_id if source_id not in best_by_source: best_by_source[source_id] = chunk - elif (chunk.relevance_score or 0) > (best_by_source[source_id].relevance_score or 0): + elif (chunk.relevance_score or 0) > ( + best_by_source[source_id].relevance_score or 0 + ): best_by_source[source_id] = chunk return list(best_by_source.values()) @@ -294,9 +110,7 @@ def apply_source_boosts( # Single query to fetch all source metadata with make_session() as db: - sources = db.query(SourceItem).filter( - SourceItem.id.in_(source_ids) - ).all() + sources = db.query(SourceItem).filter(SourceItem.id.in_(source_ids)).all() source_map = { s.id: { "title": (getattr(s, "title", 
None) or "").lower(), @@ -369,7 +183,9 @@ def fuse_scores_rrf( Dict mapping chunk IDs to RRF scores """ # Convert scores to ranks (1-indexed) - emb_ranked = sorted(embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True) + emb_ranked = sorted( + embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True + ) bm25_ranked = sorted(bm25_scores.keys(), key=lambda x: bm25_scores[x], reverse=True) emb_ranks = {chunk_id: rank + 1 for rank, chunk_id in enumerate(emb_ranked)} @@ -393,84 +209,131 @@ def fuse_scores_rrf( return fused -async def search_chunks( +async def _run_llm_analysis( + query_text: str, + use_query_analysis: bool, + use_hyde: bool, +) -> tuple[QueryAnalysis | None, str | None]: + """ + Run LLM-based query analysis and/or HyDE expansion in parallel. + + Returns: + (analysis_result, hyde_doc) tuple + """ + analysis_result: QueryAnalysis | None = None + hyde_doc: str | None = None + + if not (use_query_analysis or use_hyde): + return analysis_result, hyde_doc + + tasks = [] + + if use_query_analysis: + tasks.append(("analysis", analyze_query(query_text, timeout=3.0))) + + if use_hyde and len(query_text.split()) >= 4: + tasks.append( + ("hyde", expand_query_hyde(query_text, timeout=settings.HYDE_TIMEOUT)) + ) + + if not tasks: + return analysis_result, hyde_doc + + try: + results = await asyncio.gather( + *[task for _, task in tasks], return_exceptions=True + ) + + for i, (name, _) in enumerate(tasks): + result = results[i] + if isinstance(result, Exception): + logger.warning(f"{name} failed: {result}") + continue + + if name == "analysis" and result: + analysis_result = result + elif name == "hyde" and result: + hyde_doc = result + + except Exception as e: + logger.warning(f"Parallel LLM calls failed: {e}") + + return analysis_result, hyde_doc + + +def _apply_query_analysis( + analysis_result: QueryAnalysis, + query_text: str, data: list[extract.DataChunk], - modalities: set[str] = set(), - limit: int = 10, - filters: SearchFilters = {}, - timeout: int = 2, - config: SearchConfig = _DEFAULT_CONFIG, -) -> list[Chunk]: + modalities: set[str], +) -> tuple[str, list[extract.DataChunk], set[str], list[str]]: """ - Search chunks using embedding similarity and optionally BM25. + Apply query analysis results to modify query, data, and modalities. - Combines results using weighted score fusion, giving bonus to documents - that match both semantically and lexically. - - If HyDE is enabled, also generates a hypothetical document from the query - and includes it in the embedding search for better semantic matching. 
- - Enhancement flags in config override global settings when set: - - useBm25: Enable BM25 lexical search - - useHyde: Enable HyDE query expansion - - useReranking: Enable cross-encoder reranking - - useQueryExpansion: Enable synonym/abbreviation expansion - - useModalityDetection: Detect content type hints from query + Returns: + (updated_query_text, updated_data, updated_modalities, query_variants) """ - # Resolve enhancement flags: config overrides global settings - use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH - use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION - use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING - use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True - use_modality_detection = config.useModalityDetection if config.useModalityDetection is not None else False + query_variants: list[str] = [] - # Search for more candidates than requested, fuse scores, then return top N - # This helps find results that rank well in one method but not the other - internal_limit = limit * CANDIDATE_MULTIPLIER + if not (analysis_result and analysis_result.success): + return query_text, data, modalities, query_variants - # Extract query text - query_text = " ".join( - c for chunk in data for c in chunk.data if isinstance(c, str) - ) + # Use detected modalities if any + if analysis_result.modalities: + modalities = analysis_result.modalities + logger.debug(f"Query analysis modalities: {modalities}") - # Detect modality hints and clean query if enabled - if use_modality_detection: - cleaned_query, detected_modalities = detect_modality_hints(query_text) - if detected_modalities: - # Override passed modalities with detected ones - modalities = detected_modalities - logger.debug(f"Modality detection: '{query_text[:50]}...' -> modalities={detected_modalities}") - if cleaned_query != query_text: - logger.debug(f"Query cleaning: '{query_text[:50]}...' -> '{cleaned_query[:50]}...'") - query_text = cleaned_query - # Update data with cleaned query for downstream processing - data = [extract.DataChunk(data=[cleaned_query])] + # Use cleaned query + if analysis_result.cleaned_query and analysis_result.cleaned_query != query_text: + logger.debug( + f"Query analysis cleaning: '{query_text[:40]}...' -> '{analysis_result.cleaned_query[:40]}...'" + ) + query_text = analysis_result.cleaned_query + data = [extract.DataChunk(data=[analysis_result.cleaned_query])] - if use_query_expansion: - expanded_query = expand_query(query_text) - # If query was expanded, use expanded version for search - if expanded_query != query_text: - logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'") - search_data = [extract.DataChunk(data=[expanded_query])] - else: - search_data = list(data) # Copy to avoid modifying original - else: - search_data = list(data) + # Collect query variants + query_variants.extend(analysis_result.query_variants) - # Apply HyDE expansion if enabled - if use_hyde: - # Only expand queries with 4+ words (short queries are usually specific enough) - if len(query_text.split()) >= 4: - try: - hyde_doc = await expand_query_hyde( - query_text, timeout=settings.HYDE_TIMEOUT - ) - if hyde_doc: - logger.debug(f"HyDE expansion: '{query_text[:30]}...' 
-> '{hyde_doc[:50]}...'") - search_data.append(extract.DataChunk(data=[hyde_doc])) - except Exception as e: - logger.warning(f"HyDE expansion failed, using original query: {e}") + return query_text, data, modalities, query_variants + +def _build_search_data( + data: list[extract.DataChunk], + hyde_doc: str | None, + query_variants: list[str], + query_text: str, +) -> list[extract.DataChunk]: + """ + Build the list of data chunks to search with. + + Includes original query, HyDE expansion, and query variants. + """ + search_data = list(data) + + # Add HyDE expansion if we got one + if hyde_doc: + logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'") + search_data.append(extract.DataChunk(data=[hyde_doc])) + + # Add query variants from analysis (limit to 3) + for variant in query_variants[:3]: + search_data.append(extract.DataChunk(data=[variant])) + + return search_data + + +async def _run_searches( + search_data: list[extract.DataChunk], + data: list[extract.DataChunk], + modalities: set[str], + internal_limit: int, + filters: SearchFilters, + timeout: int, + use_bm25: bool, +) -> dict[str, float]: + """ + Run embedding and optionally BM25 searches, returning fused scores. + """ # Run embedding search embedding_scores = await search_chunks_embeddings( search_data, modalities, internal_limit, filters, timeout @@ -487,14 +350,25 @@ async def search_chunks( logger.warning("BM25 search timed out, using embedding results only") # Fuse scores from both methods using Reciprocal Rank Fusion - fused_scores = fuse_scores_rrf(embedding_scores, bm25_scores) + return fuse_scores_rrf(embedding_scores, bm25_scores) + +def _fetch_chunks( + fused_scores: dict[str, float], + limit: int, + use_reranking: bool, +) -> list[Chunk]: + """ + Fetch chunk objects from database and set their relevance scores. + """ if not fused_scores: return [] # Sort by score and take top results # If reranking is enabled, fetch more candidates for the reranker to work with - sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True) + sorted_ids = sorted( + fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True + ) if use_reranking: fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER else: @@ -522,29 +396,125 @@ async def search_chunks( db.expunge_all() - # Extract query text for boosting and reranking + return chunks + + +def _apply_boosts( + chunks: list[Chunk], + data: list[extract.DataChunk], +) -> None: + """ + Apply query term, title, popularity, and recency boosts to chunks. 
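+
+    Term and title matches add flat bonuses (QUERY_TERM_BOOST, TITLE_MATCH_BOOST);
+    popularity and recency adjustments come from source metadata, which
+    apply_source_boosts fetches in a single query. Chunks are mutated in place.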
+ """ + if not chunks: + return + + # Extract query text for boosting query_text = " ".join( c for chunk in data for c in chunk.data if isinstance(c, str) ) - # Apply query term presence boost - if chunks and query_text.strip(): + if query_text.strip(): query_terms = extract_query_terms(query_text) apply_query_term_boost(chunks, query_terms) # Apply title + popularity boosts (single DB query) apply_source_boosts(chunks, query_terms) - elif chunks: + else: # No query terms, just apply popularity boost apply_source_boosts(chunks, set()) - # Rerank using cross-encoder for better precision - if use_reranking and chunks and query_text.strip(): - try: - chunks = await rerank_chunks( - query_text, chunks, model=settings.RERANK_MODEL, top_k=limit - ) - except Exception as e: - logger.warning(f"Reranking failed, using RRF order: {e}") + +async def _apply_reranking( + chunks: list[Chunk], + query_text: str, + limit: int, + use_reranking: bool, +) -> list[Chunk]: + """ + Apply cross-encoder reranking if enabled. + """ + if not (use_reranking and chunks and query_text.strip()): + return chunks + + try: + return await rerank_chunks( + query_text, chunks, model=settings.RERANK_MODEL, top_k=limit + ) + except Exception as e: + logger.warning(f"Reranking failed, using RRF order: {e}") + return chunks + + +async def search_chunks( + data: list[extract.DataChunk], + modalities: set[str] = set(), + limit: int = 10, + filters: SearchFilters = {}, + timeout: int = 2, + config: SearchConfig = _DEFAULT_CONFIG, +) -> list[Chunk]: + """ + Search chunks using embedding similarity and optionally BM25. + + Combines results using weighted score fusion, giving bonus to documents + that match both semantically and lexically. + + If HyDE is enabled, also generates a hypothetical document from the query + and includes it in the embedding search for better semantic matching. 
+ + Enhancement flags in config override global settings when set: + - useBm25: Enable BM25 lexical search + - useHyde: Enable HyDE query expansion + - useReranking: Enable cross-encoder reranking + - useQueryAnalysis: LLM-based query analysis (extracts modalities, cleans query, generates variants) + """ + # Resolve enhancement flags: config overrides global settings + use_bm25 = ( + config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH + ) + use_hyde = ( + config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION + ) + use_reranking = ( + config.useReranking + if config.useReranking is not None + else settings.ENABLE_RERANKING + ) + use_query_analysis = ( + config.useQueryAnalysis if config.useQueryAnalysis is not None else False + ) + + internal_limit = limit * CANDIDATE_MULTIPLIER + + # Extract query text + query_text = " ".join(c for chunk in data for c in chunk.data if isinstance(c, str)) + + # Run LLM-based operations in parallel (query analysis + HyDE) + analysis_result, hyde_doc = await _run_llm_analysis( + query_text, use_query_analysis, use_hyde + ) + + # Apply query analysis results + query_text, data, modalities, query_variants = _apply_query_analysis( + analysis_result, query_text, data, modalities + ) + + # Build search data with HyDE and variants + search_data = _build_search_data(data, hyde_doc, query_variants, query_text) + + # Run searches and fuse scores + fused_scores = await _run_searches( + search_data, data, modalities, internal_limit, filters, timeout, use_bm25 + ) + + # Fetch chunks from database + chunks = _fetch_chunks(fused_scores, limit, use_reranking) + + # Apply various boosts + _apply_boosts(chunks, data) + + # Apply reranking if enabled + chunks = await _apply_reranking(chunks, query_text, limit, use_reranking) return chunks diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py index 462db56..47c9272 100644 --- a/src/memory/api/search/types.py +++ b/src/memory/api/search/types.py @@ -87,8 +87,7 @@ class SearchConfig(BaseModel): useBm25: Optional[bool] = None useHyde: Optional[bool] = None useReranking: Optional[bool] = None - useQueryExpansion: Optional[bool] = None - useModalityDetection: Optional[bool] = None + useQueryAnalysis: Optional[bool] = None # LLM-based query analysis (Haiku) def model_post_init(self, __context) -> None: # Enforce reasonable limits diff --git a/tests/memory/api/search/test_query_analysis.py b/tests/memory/api/search/test_query_analysis.py new file mode 100644 index 0000000..85ef343 --- /dev/null +++ b/tests/memory/api/search/test_query_analysis.py @@ -0,0 +1,428 @@ +"""Tests for query_analysis module.""" + +import json +import pytest +from unittest.mock import patch, Mock, AsyncMock + +from memory.api.search.query_analysis import ( + ModalityInfo, + QueryAnalysis, + _build_prompt, + _get_available_modalities, + analyze_query, + MAX_DOMAINS_TO_LIST, +) + + +class TestModalityInfo: + """Tests for ModalityInfo dataclass.""" + + def test_description_items_only(self): + """Basic description with just count.""" + info = ModalityInfo(name="forum", count=1000) + assert info.description == "1,000 items" + + def test_description_with_domains(self): + """Description includes domains when present.""" + info = ModalityInfo( + name="forum", count=1000, domains=["lesswrong.com", "example.com"] + ) + assert info.description == "1,000 items from: lesswrong.com, example.com" + + def test_description_with_source_count(self): + """Description shows sources and sections for parent 
entities.""" + info = ModalityInfo(name="book", count=22000, source_count=400) + assert info.description == "400 sources (22,000 sections)" + + def test_description_with_source_count_and_domains(self): + """Description shows sources, sections, and domains.""" + info = ModalityInfo( + name="comic", + count=8000, + source_count=50, + domains=["xkcd.com", "smbc-comics.com"], + ) + assert ( + info.description + == "50 sources (8,000 sections) from: xkcd.com, smbc-comics.com" + ) + + def test_description_formats_large_numbers(self): + """Numbers are formatted with commas.""" + info = ModalityInfo(name="book", count=1234567, source_count=9876) + assert info.description == "9,876 sources (1,234,567 sections)" + + +class TestQueryAnalysis: + """Tests for QueryAnalysis dataclass.""" + + def test_default_values(self): + """QueryAnalysis has sensible defaults.""" + result = QueryAnalysis() + assert result.modalities == set() + assert result.sources == [] + assert result.cleaned_query == "" + assert result.query_variants == [] + assert result.success is False + + def test_with_values(self): + """QueryAnalysis stores provided values.""" + result = QueryAnalysis( + modalities={"forum", "book"}, + sources=["lesswrong.com"], + cleaned_query="test query", + query_variants=["alternative query"], + success=True, + ) + assert result.modalities == {"forum", "book"} + assert result.sources == ["lesswrong.com"] + assert result.cleaned_query == "test query" + assert result.query_variants == ["alternative query"] + assert result.success is True + + +class TestBuildPrompt: + """Tests for _build_prompt function.""" + + def test_builds_prompt_with_modalities(self): + """Prompt includes modality information.""" + mock_modalities = { + "forum": ModalityInfo( + name="forum", count=20000, domains=["lesswrong.com"] + ), + "book": ModalityInfo(name="book", count=22000, source_count=400), + } + + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value=mock_modalities, + ): + prompt = _build_prompt() + + assert "forum" in prompt + assert "book" in prompt + assert "20,000 items from: lesswrong.com" in prompt + assert "400 sources (22,000 sections)" in prompt + assert "modalities" in prompt + assert "cleaned_query" in prompt + + def test_builds_prompt_empty_modalities(self): + """Prompt handles no modalities gracefully.""" + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + prompt = _build_prompt() + + assert "no content indexed yet" in prompt + + def test_prompt_contains_json_structure(self): + """Prompt includes expected JSON structure.""" + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={"test": ModalityInfo(name="test", count=100)}, + ): + prompt = _build_prompt() + + assert '"modalities": []' in prompt + assert '"sources": []' in prompt + assert '"cleaned_query": ""' in prompt + assert '"query_variants": []' in prompt + + def test_prompt_contains_guidelines(self): + """Prompt includes usage guidelines.""" + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={"forum": ModalityInfo(name="forum", count=100)}, + ): + prompt = _build_prompt() + + assert "lesswrong" in prompt.lower() + assert "comic" in prompt.lower() + assert "Remove" in prompt + assert "Return ONLY valid JSON" in prompt + + +class TestAnalyzeQuery: + """Tests for analyze_query async function.""" + + @pytest.mark.asyncio + async def test_returns_analysis_on_success(self): + """Successfully 
parses LLM JSON response.""" + mock_response = json.dumps( + { + "modalities": ["forum"], + "sources": ["lesswrong.com"], + "cleaned_query": "rationality concepts", + "query_variants": ["rational thinking", "epistemic rationality"], + } + ) + + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value=mock_response) + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={"forum": ModalityInfo(name="forum", count=100)}, + ): + # Clear any cached results + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("something on lesswrong about rationality") + + assert result.success is True + assert result.modalities == {"forum"} + assert result.sources == ["lesswrong.com"] + assert result.cleaned_query == "rationality concepts" + assert "rational thinking" in result.query_variants + + @pytest.mark.asyncio + async def test_handles_markdown_code_blocks(self): + """Strips markdown code blocks from response.""" + mock_response = """```json +{ + "modalities": ["book"], + "sources": [], + "cleaned_query": "test query", + "query_variants": [] +} +```""" + + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value=mock_response) + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={"book": ModalityInfo(name="book", count=100)}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("find something in a book") + + assert result.success is True + assert result.modalities == {"book"} + + @pytest.mark.asyncio + async def test_handles_invalid_json(self): + """Returns default result on invalid JSON.""" + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value="not valid json {{{") + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("test query") + + assert result.success is False + assert result.cleaned_query == "test query" + + @pytest.mark.asyncio + async def test_handles_timeout(self): + """Returns default result on timeout.""" + import asyncio + + async def slow_response(*args, **kwargs): + await asyncio.sleep(10) + return "{}" + + mock_provider = Mock() + mock_provider.agenerate = slow_response + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("test query", timeout=0.01) + + assert result.success is False + assert result.cleaned_query == "test query" + + @pytest.mark.asyncio + async def test_caches_results(self): + """Caches analysis results for repeated queries.""" + mock_response = json.dumps( + { + "modalities": [], + "sources": [], + "cleaned_query": "cached query", + "query_variants": [], + } + ) + + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value=mock_response) + + with 
patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + # First call + result1 = await analyze_query("test query for caching") + # Second call (should use cache) + result2 = await analyze_query("test query for caching") + + # Provider should only be called once + assert mock_provider.agenerate.call_count == 1 + assert result1.cleaned_query == result2.cleaned_query + + @pytest.mark.asyncio + async def test_cache_case_insensitive(self): + """Cache key is case-insensitive.""" + mock_response = json.dumps( + { + "modalities": [], + "sources": [], + "cleaned_query": "test", + "query_variants": [], + } + ) + + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value=mock_response) + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + await analyze_query("Test Query") + await analyze_query("test query") + await analyze_query("TEST QUERY") + + # All variations should hit the same cache entry + assert mock_provider.agenerate.call_count == 1 + + @pytest.mark.asyncio + async def test_handles_empty_response(self): + """Handles empty LLM response gracefully.""" + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value="") + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("test query") + + assert result.success is False + assert result.cleaned_query == "test query" + + @pytest.mark.asyncio + async def test_handles_none_response(self): + """Handles None LLM response gracefully.""" + mock_provider = Mock() + mock_provider.agenerate = AsyncMock(return_value=None) + + with patch( + "memory.api.search.query_analysis.create_provider", + return_value=mock_provider, + ): + with patch( + "memory.api.search.query_analysis._get_available_modalities", + return_value={}, + ): + from memory.api.search import query_analysis + query_analysis._analysis_cache = {} + + result = await analyze_query("test query") + + assert result.success is False + + +class TestGetAvailableModalities: + """Tests for _get_available_modalities function.""" + + def test_uses_cache_when_fresh(self): + """Returns cached data when cache is fresh.""" + import time + from memory.api.search import query_analysis + + # Set up cache + query_analysis._modality_cache = { + "test": ModalityInfo(name="test", count=100) + } + query_analysis._cache_timestamp = time.time() + + with patch( + "memory.api.search.query_analysis._refresh_modality_cache" + ) as mock_refresh: + result = _get_available_modalities() + + mock_refresh.assert_not_called() + assert "test" in result + + def test_refreshes_cache_when_stale(self): + """Refreshes cache when TTL has expired.""" + import time + from memory.api.search import query_analysis + + # Set up stale cache + query_analysis._modality_cache = {} + query_analysis._cache_timestamp = time.time() - 7200 # 2 hours ago + + with patch( + 
"memory.api.search.query_analysis._refresh_modality_cache" + ) as mock_refresh: + _get_available_modalities() + + mock_refresh.assert_called_once() + + def test_refreshes_cache_when_empty(self): + """Refreshes cache when cache is empty.""" + import time + from memory.api.search import query_analysis + + query_analysis._modality_cache = {} + query_analysis._cache_timestamp = time.time() + + with patch( + "memory.api.search.query_analysis._refresh_modality_cache" + ) as mock_refresh: + _get_available_modalities() + + mock_refresh.assert_called_once() diff --git a/tests/memory/api/search/test_search.py b/tests/memory/api/search/test_search.py index 09bfef6..ce1a563 100644 --- a/tests/memory/api/search/test_search.py +++ b/tests/memory/api/search/test_search.py @@ -15,8 +15,9 @@ from memory.api.search.search import ( apply_title_boost, apply_popularity_boost, apply_source_boosts, - expand_query, fuse_scores_rrf, +) +from memory.api.search.constants import ( STOPWORDS, QUERY_TERM_BOOST, TITLE_MATCH_BOOST, @@ -591,78 +592,3 @@ def test_recency_boost_ordering(mock_make_session): # Newer content should have higher score assert chunks[0].relevance_score > chunks[1].relevance_score - - -# ============================================================================ -# expand_query tests -# ============================================================================ - - -@pytest.mark.parametrize( - "query,expected_expansion", - [ - # AI/ML abbreviations - ("ML algorithms", "artificial intelligence"), # Not "ML" -> won't have AI - ("AI safety research", "artificial intelligence"), - ("NLP models for text", "natural language processing"), - ("deep learning vs DL", "deep learning"), - # Rationality/EA terms - ("EA organizations", "effective altruism"), - ("AGI timeline predictions", "artificial general intelligence"), - ("x-risk reduction", "existential risk"), - # Reverse mappings - ("machine learning basics", "ml"), - ("artificial intelligence ethics", "ai"), - ], -) -def test_expand_query_adds_synonyms(query, expected_expansion): - """Should expand abbreviations and add synonyms.""" - result = expand_query(query) - assert expected_expansion in result.lower() - assert query in result # Original query preserved - - -@pytest.mark.parametrize( - "query", - [ - "hello world", # No expansions - "python programming", # No expansions - "database optimization", # No expansions - ], -) -def test_expand_query_no_match(query): - """Should return original query when no expansions match.""" - result = expand_query(query) - assert result == query - - -def test_expand_query_case_insensitive(): - """Should match terms regardless of case.""" - assert "artificial intelligence" in expand_query("AI research").lower() - assert "artificial intelligence" in expand_query("ai research").lower() - assert "artificial intelligence" in expand_query("Ai Research").lower() - - -def test_expand_query_word_boundaries(): - """Should only match whole words, not partial matches.""" - # "mail" contains "ai" but shouldn't trigger expansion - result = expand_query("email server") - assert result == "email server" - - # "claim" contains "ai" but shouldn't trigger expansion - result = expand_query("insurance claim") - assert result == "insurance claim" - - -def test_expand_query_multiple_terms(): - """Should expand multiple matching terms.""" - result = expand_query("AI and ML applications") - assert "artificial intelligence" in result.lower() - assert "machine learning" in result.lower() - - -def test_expand_query_preserves_original(): - 
"""Should preserve original query text.""" - original = "AI safety research" - result = expand_query(original) - assert result.startswith(original)