mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
Refactor search: add LLM query analysis, extract constants
- Add query_analysis.py for LLM-based query preprocessing
- Detects modalities from natural language ("on lesswrong" -> forum)
- Cleans meta-language ("I remember reading..." -> core query)
- Generates query variants for better recall
- Dynamically discovers modalities and domains from database
- Extract constants to constants.py
- STOPWORDS, RRF_K, boost values, etc.
- Cleaner separation of configuration from logic
- Refactor search_chunks into focused helper functions
- _run_llm_analysis: parallel query analysis + HyDE
- _apply_query_analysis: apply analysis results
- _build_search_data: construct search data with variants
- _run_searches: embedding + BM25 with RRF fusion
- _fetch_chunks: database retrieval with scoring
- _apply_boosts: title, popularity, recency boosts
- _apply_reranking: cross-encoder reranking
- Remove redundant regex-based modality detection
- Remove static QUERY_EXPANSIONS (LLM handles this better)
- Add comprehensive tests for query_analysis module
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
60e6e18284
commit
782b56939f
75
src/memory/api/search/constants.py
Normal file
75
src/memory/api/search/constants.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""
|
||||
Constants for search functionality.
|
||||
"""
|
||||
|
||||
# Reciprocal Rank Fusion constant (k parameter)
|
||||
# Higher values reduce the influence of top-ranked documents
|
||||
# 60 is the standard value from the original RRF paper
|
||||
RRF_K = 60
|
||||
|
||||
# Multiplier for internal search limit before fusion
|
||||
# We search for more candidates than requested, fuse scores, then return top N
|
||||
# This helps find results that rank well in one method but not the other
|
||||
CANDIDATE_MULTIPLIER = 5
|
||||
|
||||
# How many candidates to pass to reranker (multiplier of final limit)
|
||||
# Higher = more accurate but slower and more expensive
|
||||
RERANK_CANDIDATE_MULTIPLIER = 3
|
||||
|
||||
# Bonus for chunks containing query terms (added to RRF score)
|
||||
QUERY_TERM_BOOST = 0.005
|
||||
|
||||
# Bonus when query terms match the source title (stronger signal)
|
||||
TITLE_MATCH_BOOST = 0.01
|
||||
|
||||
# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1)))
|
||||
# This gives a small boost to popular items without dominating relevance
|
||||
POPULARITY_BOOST = 0.02
|
||||
|
||||
# Recency boost settings
|
||||
# Maximum bonus for brand new content (additive)
|
||||
RECENCY_BOOST_MAX = 0.005
|
||||
# Half-life in days: content loses half its recency boost every N days
|
||||
RECENCY_HALF_LIFE_DAYS = 90
|
||||
|
||||
# Common words to ignore when checking for query term presence
|
||||
STOPWORDS = frozenset({
|
||||
# Articles
|
||||
"a", "an", "the",
|
||||
# Be verbs
|
||||
"is", "are", "was", "were", "be", "been", "being",
|
||||
# Have verbs
|
||||
"have", "has", "had",
|
||||
# Do verbs
|
||||
"do", "does", "did",
|
||||
# Modal verbs
|
||||
"will", "would", "could", "should", "may", "might", "must", "shall", "can",
|
||||
"need", "dare", "ought", "used",
|
||||
# Prepositions
|
||||
"to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into",
|
||||
"through", "during", "before", "after", "above", "below", "between", "under",
|
||||
# Adverbs
|
||||
"again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
|
||||
# Quantifiers
|
||||
"all", "each", "few", "more", "most", "other", "some", "such",
|
||||
# Negation
|
||||
"no", "nor", "not",
|
||||
# Other common words
|
||||
"only", "own", "same", "so", "than", "too", "very", "just",
|
||||
# Conjunctions
|
||||
"and", "but", "if", "or", "because", "until", "while", "although", "though",
|
||||
# Relative pronouns
|
||||
"what", "which", "who", "whom",
|
||||
# Demonstratives
|
||||
"this", "that", "these", "those",
|
||||
# Personal pronouns
|
||||
"i", "me", "my", "myself",
|
||||
"we", "our", "ours", "ourselves",
|
||||
"you", "your", "yours", "yourself", "yourselves",
|
||||
"he", "him", "his", "himself",
|
||||
"she", "her", "hers", "herself",
|
||||
"it", "its", "itself",
|
||||
"they", "them", "their", "theirs", "themselves",
|
||||
# Misc common words
|
||||
"about", "get", "got", "getting", "like", "also",
|
||||
})
|
||||
327
src/memory/api/search/query_analysis.py
Normal file
327
src/memory/api/search/query_analysis.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""
|
||||
LLM-based query analysis for intelligent search preprocessing.
|
||||
|
||||
Uses a fast LLM (Haiku) to analyze natural language queries and extract:
|
||||
- Modalities: content types to search (forum, book, comic, etc.)
|
||||
- Source hints: author names, domains, or specific sources
|
||||
- Cleaned query: the actual search terms with meta-language removed
|
||||
- Query variants: alternative phrasings to search
|
||||
|
||||
This runs in parallel with HyDE for maximum efficiency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import func, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import SourceItem
|
||||
from memory.common.llms import create_provider, LLMSettings, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Threshold for listing specific sources (if fewer than N distinct domains, list them)
|
||||
MAX_DOMAINS_TO_LIST = 10
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModalityInfo:
|
||||
"""Information about a modality from the database."""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
domains: list[str] = field(default_factory=list)
|
||||
source_count: int | None = None # For modalities with parent entities (e.g., books)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Build description including domains if there are few enough."""
|
||||
if self.source_count:
|
||||
base = f"{self.source_count:,} sources ({self.count:,} sections)"
|
||||
else:
|
||||
base = f"{self.count:,} items"
|
||||
|
||||
if self.domains:
|
||||
return f"{base} from: {', '.join(self.domains)}"
|
||||
return base
|
||||
|
||||
|
||||
# Cache for database-derived information
|
||||
_modality_cache: dict[str, ModalityInfo] = {}
|
||||
_cache_timestamp: float = 0
|
||||
_CACHE_TTL_SECONDS = 3600 # Refresh every hour
|
||||
|
||||
|
||||
def _get_tables_with_url_column(db) -> list[str]:
|
||||
"""Query database schema to find tables that have a 'url' column."""
|
||||
result = db.execute(text("""
|
||||
SELECT table_name FROM information_schema.columns
|
||||
WHERE column_name = 'url' AND table_schema = 'public'
|
||||
"""))
|
||||
return [row[0] for row in result]
|
||||
|
||||
|
||||
def _get_modality_domains(db) -> dict[str, list[str]]:
|
||||
"""Get domains for each modality that has URL data."""
|
||||
tables = _get_tables_with_url_column(db)
|
||||
if not tables:
|
||||
return {}
|
||||
|
||||
# Build a UNION query to get modality + url from all URL-containing tables
|
||||
union_parts = []
|
||||
for table in tables:
|
||||
union_parts.append(
|
||||
f"SELECT s.modality, t.url FROM source_item s "
|
||||
f"JOIN {table} t ON s.id = t.id WHERE t.url IS NOT NULL"
|
||||
)
|
||||
|
||||
if not union_parts:
|
||||
return {}
|
||||
|
||||
query = " UNION ALL ".join(union_parts)
|
||||
|
||||
try:
|
||||
result = db.execute(text(query))
|
||||
rows = list(result)
|
||||
except SQLAlchemyError as e:
|
||||
logger.debug(f"Database error getting modality URLs: {e}")
|
||||
return {}
|
||||
|
||||
# Group URLs by modality and extract domains
|
||||
modality_domains: dict[str, set[str]] = {}
|
||||
for modality, url in rows:
|
||||
if not modality or not url:
|
||||
continue
|
||||
try:
|
||||
domain = urlparse(url).netloc
|
||||
if domain:
|
||||
domain = domain.replace("www.", "")
|
||||
if modality not in modality_domains:
|
||||
modality_domains[modality] = set()
|
||||
modality_domains[modality].add(domain)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Only return domains for modalities with few enough to list
|
||||
return {
|
||||
modality: sorted(domains)
|
||||
for modality, domains in modality_domains.items()
|
||||
if len(domains) <= MAX_DOMAINS_TO_LIST
|
||||
}
|
||||
|
||||
|
||||
def _get_source_counts(db) -> dict[str, int]:
|
||||
"""Get distinct source counts for modalities with parent entities."""
|
||||
try:
|
||||
# Books: count distinct book_id from book_section
|
||||
result = db.execute(text(
|
||||
"SELECT COUNT(DISTINCT book_id) FROM book_section"
|
||||
))
|
||||
book_count = result.scalar() or 0
|
||||
return {"book": book_count}
|
||||
except SQLAlchemyError:
|
||||
return {}
|
||||
|
||||
|
||||
def _refresh_modality_cache() -> None:
|
||||
"""Query database to find modalities with actual content."""
|
||||
global _modality_cache, _cache_timestamp
|
||||
|
||||
try:
|
||||
with make_session() as db:
|
||||
# Get modality counts
|
||||
results = (
|
||||
db.query(SourceItem.modality, func.count(SourceItem.id))
|
||||
.group_by(SourceItem.modality)
|
||||
.order_by(func.count(SourceItem.id).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get domains for modalities with URLs (single query)
|
||||
modality_domains = _get_modality_domains(db)
|
||||
|
||||
# Get source counts for modalities with parent entities
|
||||
source_counts = _get_source_counts(db)
|
||||
|
||||
_modality_cache = {}
|
||||
for modality, count in results:
|
||||
if modality and count > 0:
|
||||
_modality_cache[modality] = ModalityInfo(
|
||||
name=modality,
|
||||
count=count,
|
||||
domains=modality_domains.get(modality, []),
|
||||
source_count=source_counts.get(modality),
|
||||
)
|
||||
|
||||
_cache_timestamp = time.time()
|
||||
logger.debug(f"Refreshed modality cache: {list(_modality_cache.keys())}")
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.warning(f"Database error refreshing modality cache: {e}")
|
||||
|
||||
|
||||
def _get_available_modalities() -> dict[str, ModalityInfo]:
|
||||
"""Get modalities with content, refreshing cache if needed."""
|
||||
global _cache_timestamp
|
||||
|
||||
if time.time() - _cache_timestamp > _CACHE_TTL_SECONDS or not _modality_cache:
|
||||
_refresh_modality_cache()
|
||||
|
||||
return _modality_cache
|
||||
|
||||
|
||||
def _build_prompt() -> str:
|
||||
"""Build the query analysis prompt with actual available modalities."""
|
||||
modalities = _get_available_modalities()
|
||||
|
||||
if not modalities:
|
||||
modality_section = " (no content indexed yet)"
|
||||
else:
|
||||
lines = []
|
||||
for info in modalities.values():
|
||||
lines.append(f" - {info.name}: {info.description}")
|
||||
modality_section = "\n".join(lines)
|
||||
|
||||
modality_names = list(modalities.keys()) if modalities else []
|
||||
|
||||
return (
|
||||
textwrap.dedent("""
|
||||
Analyze this search query and extract structured information.
|
||||
|
||||
The user is searching a personal knowledge base containing:
|
||||
{modality_section}
|
||||
|
||||
Return a JSON object:
|
||||
{{
|
||||
"modalities": [], // From: {modality_names} (empty = search all)
|
||||
"sources": [], // Specific sources/authors mentioned
|
||||
"cleaned_query": "", // Query with meta-language removed
|
||||
"query_variants": [] // 1-3 alternative phrasings
|
||||
}}
|
||||
|
||||
Guidelines:
|
||||
- "on lesswrong" -> forum, "comic about" -> comic, etc.
|
||||
- Remove "there was something about", "I remember reading", etc.
|
||||
- Generate useful query variants
|
||||
|
||||
Return ONLY valid JSON.
|
||||
""")
|
||||
.strip()
|
||||
.format(
|
||||
modality_section=modality_section,
|
||||
modality_names=modality_names,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryAnalysis:
|
||||
"""Result of LLM-based query analysis."""
|
||||
|
||||
modalities: set[str] = field(default_factory=set)
|
||||
sources: list[str] = field(default_factory=list)
|
||||
cleaned_query: str = ""
|
||||
query_variants: list[str] = field(default_factory=list)
|
||||
success: bool = False
|
||||
|
||||
|
||||
# Cache for recent analyses
|
||||
_analysis_cache: dict[str, QueryAnalysis] = {}
|
||||
_CACHE_MAX_SIZE = 100
|
||||
|
||||
|
||||
async def analyze_query(
|
||||
query: str,
|
||||
model: Optional[str] = None,
|
||||
timeout: float = 3.0,
|
||||
) -> QueryAnalysis:
|
||||
"""
|
||||
Analyze a search query using an LLM to extract search parameters.
|
||||
|
||||
Args:
|
||||
query: The user's natural language search query
|
||||
model: LLM model to use (defaults to SUMMARIZER_MODEL, ideally Haiku)
|
||||
timeout: Maximum time to wait for LLM response
|
||||
|
||||
Returns:
|
||||
QueryAnalysis with extracted modalities, sources, cleaned query, and variants
|
||||
"""
|
||||
# Check cache first
|
||||
cache_key = query.lower().strip()
|
||||
if cache_key in _analysis_cache:
|
||||
logger.debug(f"Query analysis cache hit for: {query[:50]}...")
|
||||
return _analysis_cache[cache_key]
|
||||
|
||||
result = QueryAnalysis(cleaned_query=query)
|
||||
|
||||
try:
|
||||
provider = create_provider(model=model or settings.SUMMARIZER_MODEL)
|
||||
|
||||
messages = [Message.user(text=f"Query: {query}")]
|
||||
|
||||
llm_settings = LLMSettings(
|
||||
temperature=0.1, # Low temperature for consistent structured output
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
provider.agenerate(
|
||||
messages=messages,
|
||||
system_prompt=_build_prompt(),
|
||||
settings=llm_settings,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if response:
|
||||
# Parse JSON response
|
||||
response = response.strip()
|
||||
# Handle markdown code blocks
|
||||
if response.startswith("```"):
|
||||
response = response.split("```")[1]
|
||||
if response.startswith("json"):
|
||||
response = response[4:]
|
||||
response = response.strip()
|
||||
|
||||
try:
|
||||
data = json.loads(response)
|
||||
|
||||
result.modalities = set(data.get("modalities", []))
|
||||
result.sources = data.get("sources", [])
|
||||
result.cleaned_query = data.get("cleaned_query", query)
|
||||
result.query_variants = data.get("query_variants", [])
|
||||
result.success = True
|
||||
|
||||
logger.debug(
|
||||
f"Query analysis: '{query[:40]}...' -> "
|
||||
f"modalities={result.modalities}, "
|
||||
f"cleaned='{result.cleaned_query[:30]}...'"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse query analysis JSON: {e}")
|
||||
result.cleaned_query = query
|
||||
|
||||
# Cache the result
|
||||
if len(_analysis_cache) >= _CACHE_MAX_SIZE:
|
||||
keys_to_remove = list(_analysis_cache.keys())[: _CACHE_MAX_SIZE // 2]
|
||||
for key in keys_to_remove:
|
||||
del _analysis_cache[key]
|
||||
_analysis_cache[cache_key] = result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Query analysis timed out for: {query[:50]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"Query analysis failed: {e}")
|
||||
|
||||
return result
|
||||
@ -5,10 +5,8 @@ Search endpoints for the knowledge base API.
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import load_only
|
||||
from memory.common import extract, settings
|
||||
from memory.common.db.connection import make_session
|
||||
@ -16,6 +14,17 @@ from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.collections import ALL_COLLECTIONS
|
||||
from memory.api.search.embeddings import search_chunks_embeddings
|
||||
from memory.api.search import scorer
|
||||
from memory.api.search.constants import (
|
||||
RRF_K,
|
||||
CANDIDATE_MULTIPLIER,
|
||||
RERANK_CANDIDATE_MULTIPLIER,
|
||||
QUERY_TERM_BOOST,
|
||||
TITLE_MATCH_BOOST,
|
||||
POPULARITY_BOOST,
|
||||
RECENCY_BOOST_MAX,
|
||||
RECENCY_HALF_LIFE_DAYS,
|
||||
STOPWORDS,
|
||||
)
|
||||
|
||||
if settings.ENABLE_BM25_SEARCH:
|
||||
from memory.api.search.bm25 import search_bm25_chunks
|
||||
@ -26,6 +35,7 @@ if settings.ENABLE_HYDE_EXPANSION:
|
||||
if settings.ENABLE_RERANKING:
|
||||
from memory.api.search.rerank import rerank_chunks
|
||||
|
||||
from memory.api.search.query_analysis import analyze_query, QueryAnalysis
|
||||
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
||||
|
||||
# Default config for when none is provided
|
||||
@ -33,202 +43,6 @@ _DEFAULT_CONFIG = SearchConfig()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reciprocal Rank Fusion constant (k parameter)
|
||||
# Higher values reduce the influence of top-ranked documents
|
||||
# 60 is the standard value from the original RRF paper
|
||||
RRF_K = 60
|
||||
|
||||
# Multiplier for internal search limit before fusion
|
||||
# We search for more candidates than requested, fuse scores, then return top N
|
||||
# This helps find results that rank well in one method but not the other
|
||||
CANDIDATE_MULTIPLIER = 5
|
||||
|
||||
# How many candidates to pass to reranker (multiplier of final limit)
|
||||
# Higher = more accurate but slower and more expensive
|
||||
RERANK_CANDIDATE_MULTIPLIER = 3
|
||||
|
||||
# Bonus for chunks containing query terms (added to RRF score)
|
||||
QUERY_TERM_BOOST = 0.005
|
||||
|
||||
# Bonus when query terms match the source title (stronger signal)
|
||||
TITLE_MATCH_BOOST = 0.01
|
||||
|
||||
# Bonus multiplier for popularity (applied as: score * (1 + POPULARITY_BOOST * (popularity - 1)))
|
||||
# This gives a small boost to popular items without dominating relevance
|
||||
POPULARITY_BOOST = 0.02
|
||||
|
||||
# Recency boost settings
|
||||
# Maximum bonus for brand new content (additive)
|
||||
RECENCY_BOOST_MAX = 0.005
|
||||
# Half-life in days: content loses half its recency boost every N days
|
||||
RECENCY_HALF_LIFE_DAYS = 90
|
||||
|
||||
# Query expansion: map abbreviations/acronyms to full forms
|
||||
# These help match when users search for "ML" but documents say "machine learning"
|
||||
QUERY_EXPANSIONS: dict[str, list[str]] = {
|
||||
# AI/ML abbreviations
|
||||
"ai": ["artificial intelligence"],
|
||||
"ml": ["machine learning"],
|
||||
"dl": ["deep learning"],
|
||||
"nlp": ["natural language processing"],
|
||||
"cv": ["computer vision"],
|
||||
"rl": ["reinforcement learning"],
|
||||
"llm": ["large language model", "language model"],
|
||||
"gpt": ["generative pretrained transformer", "language model"],
|
||||
"nn": ["neural network"],
|
||||
"cnn": ["convolutional neural network"],
|
||||
"rnn": ["recurrent neural network"],
|
||||
"lstm": ["long short term memory"],
|
||||
"gan": ["generative adversarial network"],
|
||||
"rag": ["retrieval augmented generation"],
|
||||
# Rationality/EA terms
|
||||
"ea": ["effective altruism"],
|
||||
"lw": ["lesswrong", "less wrong"],
|
||||
"gwwc": ["giving what we can"],
|
||||
"agi": ["artificial general intelligence"],
|
||||
"asi": ["artificial superintelligence"],
|
||||
"fai": ["friendly ai", "ai alignment"],
|
||||
"x-risk": ["existential risk"],
|
||||
"xrisk": ["existential risk"],
|
||||
"p(doom)": ["probability of doom", "ai risk"],
|
||||
# Reverse mappings (full forms -> abbreviations)
|
||||
"artificial intelligence": ["ai"],
|
||||
"machine learning": ["ml"],
|
||||
"deep learning": ["dl"],
|
||||
"natural language processing": ["nlp"],
|
||||
"computer vision": ["cv"],
|
||||
"reinforcement learning": ["rl"],
|
||||
"neural network": ["nn"],
|
||||
"effective altruism": ["ea"],
|
||||
"existential risk": ["x-risk", "xrisk"],
|
||||
# Family relationships (bidirectional)
|
||||
"father": ["son", "daughter", "child", "parent", "dad"],
|
||||
"mother": ["son", "daughter", "child", "parent", "mom"],
|
||||
"parent": ["child", "son", "daughter", "father", "mother"],
|
||||
"son": ["father", "parent", "child"],
|
||||
"daughter": ["mother", "parent", "child"],
|
||||
"child": ["parent", "father", "mother"],
|
||||
"dad": ["father", "son", "daughter", "child"],
|
||||
"mom": ["mother", "son", "daughter", "child"],
|
||||
}
|
||||
|
||||
# Modality detection patterns: map query phrases to collection names
|
||||
# Each entry is (pattern, modalities, strip_pattern)
|
||||
# - pattern: regex to match in query
|
||||
# - modalities: set of collection names to filter to
|
||||
# - strip_pattern: whether to remove the matched text from query
|
||||
MODALITY_PATTERNS: list[tuple[str, set[str], bool]] = [
|
||||
# Comics
|
||||
(r"\b(comic|comics|webcomic|webcomics)\b", {"comic"}, True),
|
||||
# Forum posts (LessWrong, EA Forum, etc.)
|
||||
(r"\b(on\s+)?(lesswrong|lw|less\s+wrong)\b", {"forum"}, True),
|
||||
(r"\b(on\s+)?(ea\s+forum|effective\s+altruism\s+forum)\b", {"forum"}, True),
|
||||
(r"\b(on\s+)?(alignment\s+forum|af)\b", {"forum"}, True),
|
||||
(r"\b(forum\s+post|lw\s+post|post\s+on)\b", {"forum"}, True),
|
||||
# Books
|
||||
(r"\b(in\s+a\s+book|in\s+the\s+book|book|chapter)\b", {"book"}, True),
|
||||
# Blog posts / articles
|
||||
(r"\b(blog\s+post|blog|article)\b", {"blog"}, True),
|
||||
# Email
|
||||
(r"\b(email|e-mail|mail)\b", {"mail"}, True),
|
||||
# Photos / images
|
||||
(r"\b(photo|photograph|picture|image)\b", {"photo"}, True),
|
||||
# Documents
|
||||
(r"\b(document|pdf|doc)\b", {"doc"}, True),
|
||||
# Chat / messages
|
||||
(r"\b(chat|message|discord|slack)\b", {"chat"}, True),
|
||||
# Git
|
||||
(r"\b(commit|git|pull\s+request|pr)\b", {"git"}, True),
|
||||
]
|
||||
|
||||
# Meta-language patterns to strip (these don't indicate modality, just noise)
|
||||
META_LANGUAGE_PATTERNS: list[str] = [
|
||||
r"\bthere\s+was\s+(something|some|some\s+\w+|an?\s+\w+)\s+(about|on)\b",
|
||||
r"\bi\s+remember\s+(reading|seeing|there\s+being)\s*(an?\s+)?",
|
||||
r"\bi\s+(read|saw|found)\s+(something|an?\s+\w+)\s+about\b",
|
||||
r"\bsomething\s+about\b",
|
||||
r"\bsome\s+about\b",
|
||||
r"\bthis\s+whole\s+\w+\s+thing\b",
|
||||
r"\bthat\s+\w+\s+thing\b",
|
||||
r"\bthat\s+about\b", # Clean up leftover "that about"
|
||||
r"\ba\s+about\b", # Clean up leftover "a about"
|
||||
r"\bthe\s+about\b", # Clean up leftover "the about"
|
||||
r"\bthere\s+was\s+some\s+about\b", # Clean up leftover
|
||||
]
|
||||
|
||||
|
||||
def detect_modality_hints(query: str) -> tuple[str, set[str]]:
|
||||
"""
|
||||
Detect content type hints in query and extract modalities.
|
||||
|
||||
Returns:
|
||||
(cleaned_query, detected_modalities)
|
||||
- cleaned_query: query with modality indicators and meta-language removed
|
||||
- detected_modalities: set of collection names detected from query
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
detected: set[str] = set()
|
||||
cleaned = query
|
||||
|
||||
# First, detect and strip modality patterns
|
||||
for pattern, modalities, strip in MODALITY_PATTERNS:
|
||||
if re.search(pattern, query_lower, re.IGNORECASE):
|
||||
detected.update(modalities)
|
||||
if strip:
|
||||
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
|
||||
|
||||
# Strip meta-language patterns (regardless of modality detection)
|
||||
for pattern in META_LANGUAGE_PATTERNS:
|
||||
cleaned = re.sub(pattern, " ", cleaned, flags=re.IGNORECASE)
|
||||
|
||||
# Clean up whitespace
|
||||
cleaned = " ".join(cleaned.split())
|
||||
|
||||
return cleaned, detected
|
||||
|
||||
|
||||
def expand_query(query: str) -> str:
|
||||
"""
|
||||
Expand query with synonyms and abbreviations.
|
||||
|
||||
This helps match documents that use different terminology for the same concept.
|
||||
For example, "ML algorithms" -> "ML machine learning algorithms"
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
expansions = []
|
||||
|
||||
for term, synonyms in QUERY_EXPANSIONS.items():
|
||||
# Check if term appears as a whole word in the query
|
||||
# Use word boundaries to avoid matching partial words
|
||||
pattern = r'\b' + re.escape(term) + r'\b'
|
||||
if re.search(pattern, query_lower):
|
||||
expansions.extend(synonyms)
|
||||
|
||||
if expansions:
|
||||
# Add expansions to the original query
|
||||
return query + " " + " ".join(expansions)
|
||||
return query
|
||||
|
||||
|
||||
# Common words to ignore when checking for query term presence
|
||||
STOPWORDS = {
|
||||
"a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
|
||||
"have", "has", "had", "do", "does", "did", "will", "would", "could",
|
||||
"should", "may", "might", "must", "shall", "can", "need", "dare",
|
||||
"ought", "used", "to", "of", "in", "for", "on", "with", "at", "by",
|
||||
"from", "as", "into", "through", "during", "before", "after", "above",
|
||||
"below", "between", "under", "again", "further", "then", "once", "here",
|
||||
"there", "when", "where", "why", "how", "all", "each", "few", "more",
|
||||
"most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
||||
"same", "so", "than", "too", "very", "just", "and", "but", "if", "or",
|
||||
"because", "until", "while", "although", "though", "after", "before",
|
||||
"what", "which", "who", "whom", "this", "that", "these", "those", "i",
|
||||
"me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your",
|
||||
"yours", "yourself", "yourselves", "he", "him", "his", "himself", "she",
|
||||
"her", "hers", "herself", "it", "its", "itself", "they", "them", "their",
|
||||
"theirs", "themselves", "about", "get", "got", "getting", "like", "also",
|
||||
}
|
||||
|
||||
|
||||
def extract_query_terms(query: str) -> set[str]:
|
||||
"""Extract meaningful terms from query, filtering stopwords."""
|
||||
@ -270,7 +84,9 @@ def deduplicate_by_source(chunks: list[Chunk]) -> list[Chunk]:
|
||||
source_id = chunk.source_id
|
||||
if source_id not in best_by_source:
|
||||
best_by_source[source_id] = chunk
|
||||
elif (chunk.relevance_score or 0) > (best_by_source[source_id].relevance_score or 0):
|
||||
elif (chunk.relevance_score or 0) > (
|
||||
best_by_source[source_id].relevance_score or 0
|
||||
):
|
||||
best_by_source[source_id] = chunk
|
||||
return list(best_by_source.values())
|
||||
|
||||
@ -294,9 +110,7 @@ def apply_source_boosts(
|
||||
|
||||
# Single query to fetch all source metadata
|
||||
with make_session() as db:
|
||||
sources = db.query(SourceItem).filter(
|
||||
SourceItem.id.in_(source_ids)
|
||||
).all()
|
||||
sources = db.query(SourceItem).filter(SourceItem.id.in_(source_ids)).all()
|
||||
source_map = {
|
||||
s.id: {
|
||||
"title": (getattr(s, "title", None) or "").lower(),
|
||||
@ -369,7 +183,9 @@ def fuse_scores_rrf(
|
||||
Dict mapping chunk IDs to RRF scores
|
||||
"""
|
||||
# Convert scores to ranks (1-indexed)
|
||||
emb_ranked = sorted(embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True)
|
||||
emb_ranked = sorted(
|
||||
embedding_scores.keys(), key=lambda x: embedding_scores[x], reverse=True
|
||||
)
|
||||
bm25_ranked = sorted(bm25_scores.keys(), key=lambda x: bm25_scores[x], reverse=True)
|
||||
|
||||
emb_ranks = {chunk_id: rank + 1 for rank, chunk_id in enumerate(emb_ranked)}
|
||||
@ -393,84 +209,131 @@ def fuse_scores_rrf(
|
||||
return fused
|
||||
|
||||
|
||||
async def search_chunks(
|
||||
async def _run_llm_analysis(
|
||||
query_text: str,
|
||||
use_query_analysis: bool,
|
||||
use_hyde: bool,
|
||||
) -> tuple[QueryAnalysis | None, str | None]:
|
||||
"""
|
||||
Run LLM-based query analysis and/or HyDE expansion in parallel.
|
||||
|
||||
Returns:
|
||||
(analysis_result, hyde_doc) tuple
|
||||
"""
|
||||
analysis_result: QueryAnalysis | None = None
|
||||
hyde_doc: str | None = None
|
||||
|
||||
if not (use_query_analysis or use_hyde):
|
||||
return analysis_result, hyde_doc
|
||||
|
||||
tasks = []
|
||||
|
||||
if use_query_analysis:
|
||||
tasks.append(("analysis", analyze_query(query_text, timeout=3.0)))
|
||||
|
||||
if use_hyde and len(query_text.split()) >= 4:
|
||||
tasks.append(
|
||||
("hyde", expand_query_hyde(query_text, timeout=settings.HYDE_TIMEOUT))
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
return analysis_result, hyde_doc
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(
|
||||
*[task for _, task in tasks], return_exceptions=True
|
||||
)
|
||||
|
||||
for i, (name, _) in enumerate(tasks):
|
||||
result = results[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"{name} failed: {result}")
|
||||
continue
|
||||
|
||||
if name == "analysis" and result:
|
||||
analysis_result = result
|
||||
elif name == "hyde" and result:
|
||||
hyde_doc = result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Parallel LLM calls failed: {e}")
|
||||
|
||||
return analysis_result, hyde_doc
|
||||
|
||||
|
||||
def _apply_query_analysis(
|
||||
analysis_result: QueryAnalysis,
|
||||
query_text: str,
|
||||
data: list[extract.DataChunk],
|
||||
modalities: set[str] = set(),
|
||||
limit: int = 10,
|
||||
filters: SearchFilters = {},
|
||||
timeout: int = 2,
|
||||
config: SearchConfig = _DEFAULT_CONFIG,
|
||||
) -> list[Chunk]:
|
||||
modalities: set[str],
|
||||
) -> tuple[str, list[extract.DataChunk], set[str], list[str]]:
|
||||
"""
|
||||
Search chunks using embedding similarity and optionally BM25.
|
||||
Apply query analysis results to modify query, data, and modalities.
|
||||
|
||||
Combines results using weighted score fusion, giving bonus to documents
|
||||
that match both semantically and lexically.
|
||||
|
||||
If HyDE is enabled, also generates a hypothetical document from the query
|
||||
and includes it in the embedding search for better semantic matching.
|
||||
|
||||
Enhancement flags in config override global settings when set:
|
||||
- useBm25: Enable BM25 lexical search
|
||||
- useHyde: Enable HyDE query expansion
|
||||
- useReranking: Enable cross-encoder reranking
|
||||
- useQueryExpansion: Enable synonym/abbreviation expansion
|
||||
- useModalityDetection: Detect content type hints from query
|
||||
Returns:
|
||||
(updated_query_text, updated_data, updated_modalities, query_variants)
|
||||
"""
|
||||
# Resolve enhancement flags: config overrides global settings
|
||||
use_bm25 = config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
||||
use_hyde = config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
||||
use_reranking = config.useReranking if config.useReranking is not None else settings.ENABLE_RERANKING
|
||||
use_query_expansion = config.useQueryExpansion if config.useQueryExpansion is not None else True
|
||||
use_modality_detection = config.useModalityDetection if config.useModalityDetection is not None else False
|
||||
query_variants: list[str] = []
|
||||
|
||||
# Search for more candidates than requested, fuse scores, then return top N
|
||||
# This helps find results that rank well in one method but not the other
|
||||
internal_limit = limit * CANDIDATE_MULTIPLIER
|
||||
if not (analysis_result and analysis_result.success):
|
||||
return query_text, data, modalities, query_variants
|
||||
|
||||
# Extract query text
|
||||
query_text = " ".join(
|
||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
||||
)
|
||||
# Use detected modalities if any
|
||||
if analysis_result.modalities:
|
||||
modalities = analysis_result.modalities
|
||||
logger.debug(f"Query analysis modalities: {modalities}")
|
||||
|
||||
# Detect modality hints and clean query if enabled
|
||||
if use_modality_detection:
|
||||
cleaned_query, detected_modalities = detect_modality_hints(query_text)
|
||||
if detected_modalities:
|
||||
# Override passed modalities with detected ones
|
||||
modalities = detected_modalities
|
||||
logger.debug(f"Modality detection: '{query_text[:50]}...' -> modalities={detected_modalities}")
|
||||
if cleaned_query != query_text:
|
||||
logger.debug(f"Query cleaning: '{query_text[:50]}...' -> '{cleaned_query[:50]}...'")
|
||||
query_text = cleaned_query
|
||||
# Update data with cleaned query for downstream processing
|
||||
data = [extract.DataChunk(data=[cleaned_query])]
|
||||
# Use cleaned query
|
||||
if analysis_result.cleaned_query and analysis_result.cleaned_query != query_text:
|
||||
logger.debug(
|
||||
f"Query analysis cleaning: '{query_text[:40]}...' -> '{analysis_result.cleaned_query[:40]}...'"
|
||||
)
|
||||
query_text = analysis_result.cleaned_query
|
||||
data = [extract.DataChunk(data=[analysis_result.cleaned_query])]
|
||||
|
||||
if use_query_expansion:
|
||||
expanded_query = expand_query(query_text)
|
||||
# If query was expanded, use expanded version for search
|
||||
if expanded_query != query_text:
|
||||
logger.debug(f"Query expansion: '{query_text}' -> '{expanded_query}'")
|
||||
search_data = [extract.DataChunk(data=[expanded_query])]
|
||||
else:
|
||||
search_data = list(data) # Copy to avoid modifying original
|
||||
else:
|
||||
search_data = list(data)
|
||||
# Collect query variants
|
||||
query_variants.extend(analysis_result.query_variants)
|
||||
|
||||
# Apply HyDE expansion if enabled
|
||||
if use_hyde:
|
||||
# Only expand queries with 4+ words (short queries are usually specific enough)
|
||||
if len(query_text.split()) >= 4:
|
||||
try:
|
||||
hyde_doc = await expand_query_hyde(
|
||||
query_text, timeout=settings.HYDE_TIMEOUT
|
||||
)
|
||||
if hyde_doc:
|
||||
logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'")
|
||||
search_data.append(extract.DataChunk(data=[hyde_doc]))
|
||||
except Exception as e:
|
||||
logger.warning(f"HyDE expansion failed, using original query: {e}")
|
||||
return query_text, data, modalities, query_variants
|
||||
|
||||
|
||||
def _build_search_data(
|
||||
data: list[extract.DataChunk],
|
||||
hyde_doc: str | None,
|
||||
query_variants: list[str],
|
||||
query_text: str,
|
||||
) -> list[extract.DataChunk]:
|
||||
"""
|
||||
Build the list of data chunks to search with.
|
||||
|
||||
Includes original query, HyDE expansion, and query variants.
|
||||
"""
|
||||
search_data = list(data)
|
||||
|
||||
# Add HyDE expansion if we got one
|
||||
if hyde_doc:
|
||||
logger.debug(f"HyDE expansion: '{query_text[:30]}...' -> '{hyde_doc[:50]}...'")
|
||||
search_data.append(extract.DataChunk(data=[hyde_doc]))
|
||||
|
||||
# Add query variants from analysis (limit to 3)
|
||||
for variant in query_variants[:3]:
|
||||
search_data.append(extract.DataChunk(data=[variant]))
|
||||
|
||||
return search_data
|
||||
|
||||
|
||||
async def _run_searches(
|
||||
search_data: list[extract.DataChunk],
|
||||
data: list[extract.DataChunk],
|
||||
modalities: set[str],
|
||||
internal_limit: int,
|
||||
filters: SearchFilters,
|
||||
timeout: int,
|
||||
use_bm25: bool,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Run embedding and optionally BM25 searches, returning fused scores.
|
||||
"""
|
||||
# Run embedding search
|
||||
embedding_scores = await search_chunks_embeddings(
|
||||
search_data, modalities, internal_limit, filters, timeout
|
||||
@ -487,14 +350,25 @@ async def search_chunks(
|
||||
logger.warning("BM25 search timed out, using embedding results only")
|
||||
|
||||
# Fuse scores from both methods using Reciprocal Rank Fusion
|
||||
fused_scores = fuse_scores_rrf(embedding_scores, bm25_scores)
|
||||
return fuse_scores_rrf(embedding_scores, bm25_scores)
|
||||
|
||||
|
||||
def _fetch_chunks(
|
||||
fused_scores: dict[str, float],
|
||||
limit: int,
|
||||
use_reranking: bool,
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Fetch chunk objects from database and set their relevance scores.
|
||||
"""
|
||||
if not fused_scores:
|
||||
return []
|
||||
|
||||
# Sort by score and take top results
|
||||
# If reranking is enabled, fetch more candidates for the reranker to work with
|
||||
sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
|
||||
sorted_ids = sorted(
|
||||
fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True
|
||||
)
|
||||
if use_reranking:
|
||||
fetch_limit = limit * RERANK_CANDIDATE_MULTIPLIER
|
||||
else:
|
||||
@ -522,29 +396,125 @@ async def search_chunks(
|
||||
|
||||
db.expunge_all()
|
||||
|
||||
# Extract query text for boosting and reranking
|
||||
return chunks
|
||||
|
||||
|
||||
def _apply_boosts(
|
||||
chunks: list[Chunk],
|
||||
data: list[extract.DataChunk],
|
||||
) -> None:
|
||||
"""
|
||||
Apply query term, title, popularity, and recency boosts to chunks.
|
||||
"""
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
# Extract query text for boosting
|
||||
query_text = " ".join(
|
||||
c for chunk in data for c in chunk.data if isinstance(c, str)
|
||||
)
|
||||
|
||||
# Apply query term presence boost
|
||||
if chunks and query_text.strip():
|
||||
if query_text.strip():
|
||||
query_terms = extract_query_terms(query_text)
|
||||
apply_query_term_boost(chunks, query_terms)
|
||||
# Apply title + popularity boosts (single DB query)
|
||||
apply_source_boosts(chunks, query_terms)
|
||||
elif chunks:
|
||||
else:
|
||||
# No query terms, just apply popularity boost
|
||||
apply_source_boosts(chunks, set())
|
||||
|
||||
# Rerank using cross-encoder for better precision
|
||||
if use_reranking and chunks and query_text.strip():
|
||||
try:
|
||||
chunks = await rerank_chunks(
|
||||
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Reranking failed, using RRF order: {e}")
|
||||
|
||||
async def _apply_reranking(
|
||||
chunks: list[Chunk],
|
||||
query_text: str,
|
||||
limit: int,
|
||||
use_reranking: bool,
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Apply cross-encoder reranking if enabled.
|
||||
"""
|
||||
if not (use_reranking and chunks and query_text.strip()):
|
||||
return chunks
|
||||
|
||||
try:
|
||||
return await rerank_chunks(
|
||||
query_text, chunks, model=settings.RERANK_MODEL, top_k=limit
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Reranking failed, using RRF order: {e}")
|
||||
return chunks
|
||||
|
||||
|
||||
async def search_chunks(
|
||||
data: list[extract.DataChunk],
|
||||
modalities: set[str] = set(),
|
||||
limit: int = 10,
|
||||
filters: SearchFilters = {},
|
||||
timeout: int = 2,
|
||||
config: SearchConfig = _DEFAULT_CONFIG,
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Search chunks using embedding similarity and optionally BM25.
|
||||
|
||||
Combines results using weighted score fusion, giving bonus to documents
|
||||
that match both semantically and lexically.
|
||||
|
||||
If HyDE is enabled, also generates a hypothetical document from the query
|
||||
and includes it in the embedding search for better semantic matching.
|
||||
|
||||
Enhancement flags in config override global settings when set:
|
||||
- useBm25: Enable BM25 lexical search
|
||||
- useHyde: Enable HyDE query expansion
|
||||
- useReranking: Enable cross-encoder reranking
|
||||
- useQueryAnalysis: LLM-based query analysis (extracts modalities, cleans query, generates variants)
|
||||
"""
|
||||
# Resolve enhancement flags: config overrides global settings
|
||||
use_bm25 = (
|
||||
config.useBm25 if config.useBm25 is not None else settings.ENABLE_BM25_SEARCH
|
||||
)
|
||||
use_hyde = (
|
||||
config.useHyde if config.useHyde is not None else settings.ENABLE_HYDE_EXPANSION
|
||||
)
|
||||
use_reranking = (
|
||||
config.useReranking
|
||||
if config.useReranking is not None
|
||||
else settings.ENABLE_RERANKING
|
||||
)
|
||||
use_query_analysis = (
|
||||
config.useQueryAnalysis if config.useQueryAnalysis is not None else False
|
||||
)
|
||||
|
||||
internal_limit = limit * CANDIDATE_MULTIPLIER
|
||||
|
||||
# Extract query text
|
||||
query_text = " ".join(c for chunk in data for c in chunk.data if isinstance(c, str))
|
||||
|
||||
# Run LLM-based operations in parallel (query analysis + HyDE)
|
||||
analysis_result, hyde_doc = await _run_llm_analysis(
|
||||
query_text, use_query_analysis, use_hyde
|
||||
)
|
||||
|
||||
# Apply query analysis results
|
||||
query_text, data, modalities, query_variants = _apply_query_analysis(
|
||||
analysis_result, query_text, data, modalities
|
||||
)
|
||||
|
||||
# Build search data with HyDE and variants
|
||||
search_data = _build_search_data(data, hyde_doc, query_variants, query_text)
|
||||
|
||||
# Run searches and fuse scores
|
||||
fused_scores = await _run_searches(
|
||||
search_data, data, modalities, internal_limit, filters, timeout, use_bm25
|
||||
)
|
||||
|
||||
# Fetch chunks from database
|
||||
chunks = _fetch_chunks(fused_scores, limit, use_reranking)
|
||||
|
||||
# Apply various boosts
|
||||
_apply_boosts(chunks, data)
|
||||
|
||||
# Apply reranking if enabled
|
||||
chunks = await _apply_reranking(chunks, query_text, limit, use_reranking)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@ -87,8 +87,7 @@ class SearchConfig(BaseModel):
|
||||
useBm25: Optional[bool] = None
|
||||
useHyde: Optional[bool] = None
|
||||
useReranking: Optional[bool] = None
|
||||
useQueryExpansion: Optional[bool] = None
|
||||
useModalityDetection: Optional[bool] = None
|
||||
useQueryAnalysis: Optional[bool] = None # LLM-based query analysis (Haiku)
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# Enforce reasonable limits
|
||||
|
||||
428
tests/memory/api/search/test_query_analysis.py
Normal file
428
tests/memory/api/search/test_query_analysis.py
Normal file
@ -0,0 +1,428 @@
|
||||
"""Tests for query_analysis module."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, Mock, AsyncMock
|
||||
|
||||
from memory.api.search.query_analysis import (
|
||||
ModalityInfo,
|
||||
QueryAnalysis,
|
||||
_build_prompt,
|
||||
_get_available_modalities,
|
||||
analyze_query,
|
||||
MAX_DOMAINS_TO_LIST,
|
||||
)
|
||||
|
||||
|
||||
class TestModalityInfo:
|
||||
"""Tests for ModalityInfo dataclass."""
|
||||
|
||||
def test_description_items_only(self):
|
||||
"""Basic description with just count."""
|
||||
info = ModalityInfo(name="forum", count=1000)
|
||||
assert info.description == "1,000 items"
|
||||
|
||||
def test_description_with_domains(self):
|
||||
"""Description includes domains when present."""
|
||||
info = ModalityInfo(
|
||||
name="forum", count=1000, domains=["lesswrong.com", "example.com"]
|
||||
)
|
||||
assert info.description == "1,000 items from: lesswrong.com, example.com"
|
||||
|
||||
def test_description_with_source_count(self):
|
||||
"""Description shows sources and sections for parent entities."""
|
||||
info = ModalityInfo(name="book", count=22000, source_count=400)
|
||||
assert info.description == "400 sources (22,000 sections)"
|
||||
|
||||
def test_description_with_source_count_and_domains(self):
|
||||
"""Description shows sources, sections, and domains."""
|
||||
info = ModalityInfo(
|
||||
name="comic",
|
||||
count=8000,
|
||||
source_count=50,
|
||||
domains=["xkcd.com", "smbc-comics.com"],
|
||||
)
|
||||
assert (
|
||||
info.description
|
||||
== "50 sources (8,000 sections) from: xkcd.com, smbc-comics.com"
|
||||
)
|
||||
|
||||
def test_description_formats_large_numbers(self):
|
||||
"""Numbers are formatted with commas."""
|
||||
info = ModalityInfo(name="book", count=1234567, source_count=9876)
|
||||
assert info.description == "9,876 sources (1,234,567 sections)"
|
||||
|
||||
|
||||
class TestQueryAnalysis:
|
||||
"""Tests for QueryAnalysis dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""QueryAnalysis has sensible defaults."""
|
||||
result = QueryAnalysis()
|
||||
assert result.modalities == set()
|
||||
assert result.sources == []
|
||||
assert result.cleaned_query == ""
|
||||
assert result.query_variants == []
|
||||
assert result.success is False
|
||||
|
||||
def test_with_values(self):
|
||||
"""QueryAnalysis stores provided values."""
|
||||
result = QueryAnalysis(
|
||||
modalities={"forum", "book"},
|
||||
sources=["lesswrong.com"],
|
||||
cleaned_query="test query",
|
||||
query_variants=["alternative query"],
|
||||
success=True,
|
||||
)
|
||||
assert result.modalities == {"forum", "book"}
|
||||
assert result.sources == ["lesswrong.com"]
|
||||
assert result.cleaned_query == "test query"
|
||||
assert result.query_variants == ["alternative query"]
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestBuildPrompt:
|
||||
"""Tests for _build_prompt function."""
|
||||
|
||||
def test_builds_prompt_with_modalities(self):
|
||||
"""Prompt includes modality information."""
|
||||
mock_modalities = {
|
||||
"forum": ModalityInfo(
|
||||
name="forum", count=20000, domains=["lesswrong.com"]
|
||||
),
|
||||
"book": ModalityInfo(name="book", count=22000, source_count=400),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value=mock_modalities,
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert "forum" in prompt
|
||||
assert "book" in prompt
|
||||
assert "20,000 items from: lesswrong.com" in prompt
|
||||
assert "400 sources (22,000 sections)" in prompt
|
||||
assert "modalities" in prompt
|
||||
assert "cleaned_query" in prompt
|
||||
|
||||
def test_builds_prompt_empty_modalities(self):
|
||||
"""Prompt handles no modalities gracefully."""
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert "no content indexed yet" in prompt
|
||||
|
||||
def test_prompt_contains_json_structure(self):
|
||||
"""Prompt includes expected JSON structure."""
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={"test": ModalityInfo(name="test", count=100)},
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert '"modalities": []' in prompt
|
||||
assert '"sources": []' in prompt
|
||||
assert '"cleaned_query": ""' in prompt
|
||||
assert '"query_variants": []' in prompt
|
||||
|
||||
def test_prompt_contains_guidelines(self):
|
||||
"""Prompt includes usage guidelines."""
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={"forum": ModalityInfo(name="forum", count=100)},
|
||||
):
|
||||
prompt = _build_prompt()
|
||||
|
||||
assert "lesswrong" in prompt.lower()
|
||||
assert "comic" in prompt.lower()
|
||||
assert "Remove" in prompt
|
||||
assert "Return ONLY valid JSON" in prompt
|
||||
|
||||
|
||||
class TestAnalyzeQuery:
|
||||
"""Tests for analyze_query async function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_analysis_on_success(self):
|
||||
"""Successfully parses LLM JSON response."""
|
||||
mock_response = json.dumps(
|
||||
{
|
||||
"modalities": ["forum"],
|
||||
"sources": ["lesswrong.com"],
|
||||
"cleaned_query": "rationality concepts",
|
||||
"query_variants": ["rational thinking", "epistemic rationality"],
|
||||
}
|
||||
)
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={"forum": ModalityInfo(name="forum", count=100)},
|
||||
):
|
||||
# Clear any cached results
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("something on lesswrong about rationality")
|
||||
|
||||
assert result.success is True
|
||||
assert result.modalities == {"forum"}
|
||||
assert result.sources == ["lesswrong.com"]
|
||||
assert result.cleaned_query == "rationality concepts"
|
||||
assert "rational thinking" in result.query_variants
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_markdown_code_blocks(self):
|
||||
"""Strips markdown code blocks from response."""
|
||||
mock_response = """```json
|
||||
{
|
||||
"modalities": ["book"],
|
||||
"sources": [],
|
||||
"cleaned_query": "test query",
|
||||
"query_variants": []
|
||||
}
|
||||
```"""
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={"book": ModalityInfo(name="book", count=100)},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("find something in a book")
|
||||
|
||||
assert result.success is True
|
||||
assert result.modalities == {"book"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_invalid_json(self):
|
||||
"""Returns default result on invalid JSON."""
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value="not valid json {{{")
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("test query")
|
||||
|
||||
assert result.success is False
|
||||
assert result.cleaned_query == "test query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_timeout(self):
|
||||
"""Returns default result on timeout."""
|
||||
import asyncio
|
||||
|
||||
async def slow_response(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
return "{}"
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = slow_response
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("test query", timeout=0.01)
|
||||
|
||||
assert result.success is False
|
||||
assert result.cleaned_query == "test query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caches_results(self):
|
||||
"""Caches analysis results for repeated queries."""
|
||||
mock_response = json.dumps(
|
||||
{
|
||||
"modalities": [],
|
||||
"sources": [],
|
||||
"cleaned_query": "cached query",
|
||||
"query_variants": [],
|
||||
}
|
||||
)
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
# First call
|
||||
result1 = await analyze_query("test query for caching")
|
||||
# Second call (should use cache)
|
||||
result2 = await analyze_query("test query for caching")
|
||||
|
||||
# Provider should only be called once
|
||||
assert mock_provider.agenerate.call_count == 1
|
||||
assert result1.cleaned_query == result2.cleaned_query
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_case_insensitive(self):
|
||||
"""Cache key is case-insensitive."""
|
||||
mock_response = json.dumps(
|
||||
{
|
||||
"modalities": [],
|
||||
"sources": [],
|
||||
"cleaned_query": "test",
|
||||
"query_variants": [],
|
||||
}
|
||||
)
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
await analyze_query("Test Query")
|
||||
await analyze_query("test query")
|
||||
await analyze_query("TEST QUERY")
|
||||
|
||||
# All variations should hit the same cache entry
|
||||
assert mock_provider.agenerate.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_empty_response(self):
|
||||
"""Handles empty LLM response gracefully."""
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value="")
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("test query")
|
||||
|
||||
assert result.success is False
|
||||
assert result.cleaned_query == "test query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_none_response(self):
|
||||
"""Handles None LLM response gracefully."""
|
||||
mock_provider = Mock()
|
||||
mock_provider.agenerate = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis.create_provider",
|
||||
return_value=mock_provider,
|
||||
):
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._get_available_modalities",
|
||||
return_value={},
|
||||
):
|
||||
from memory.api.search import query_analysis
|
||||
query_analysis._analysis_cache = {}
|
||||
|
||||
result = await analyze_query("test query")
|
||||
|
||||
assert result.success is False
|
||||
|
||||
|
||||
class TestGetAvailableModalities:
|
||||
"""Tests for _get_available_modalities function."""
|
||||
|
||||
def test_uses_cache_when_fresh(self):
|
||||
"""Returns cached data when cache is fresh."""
|
||||
import time
|
||||
from memory.api.search import query_analysis
|
||||
|
||||
# Set up cache
|
||||
query_analysis._modality_cache = {
|
||||
"test": ModalityInfo(name="test", count=100)
|
||||
}
|
||||
query_analysis._cache_timestamp = time.time()
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._refresh_modality_cache"
|
||||
) as mock_refresh:
|
||||
result = _get_available_modalities()
|
||||
|
||||
mock_refresh.assert_not_called()
|
||||
assert "test" in result
|
||||
|
||||
def test_refreshes_cache_when_stale(self):
|
||||
"""Refreshes cache when TTL has expired."""
|
||||
import time
|
||||
from memory.api.search import query_analysis
|
||||
|
||||
# Set up stale cache
|
||||
query_analysis._modality_cache = {}
|
||||
query_analysis._cache_timestamp = time.time() - 7200 # 2 hours ago
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._refresh_modality_cache"
|
||||
) as mock_refresh:
|
||||
_get_available_modalities()
|
||||
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
def test_refreshes_cache_when_empty(self):
|
||||
"""Refreshes cache when cache is empty."""
|
||||
import time
|
||||
from memory.api.search import query_analysis
|
||||
|
||||
query_analysis._modality_cache = {}
|
||||
query_analysis._cache_timestamp = time.time()
|
||||
|
||||
with patch(
|
||||
"memory.api.search.query_analysis._refresh_modality_cache"
|
||||
) as mock_refresh:
|
||||
_get_available_modalities()
|
||||
|
||||
mock_refresh.assert_called_once()
|
||||
@ -15,8 +15,9 @@ from memory.api.search.search import (
|
||||
apply_title_boost,
|
||||
apply_popularity_boost,
|
||||
apply_source_boosts,
|
||||
expand_query,
|
||||
fuse_scores_rrf,
|
||||
)
|
||||
from memory.api.search.constants import (
|
||||
STOPWORDS,
|
||||
QUERY_TERM_BOOST,
|
||||
TITLE_MATCH_BOOST,
|
||||
@ -591,78 +592,3 @@ def test_recency_boost_ordering(mock_make_session):
|
||||
|
||||
# Newer content should have higher score
|
||||
assert chunks[0].relevance_score > chunks[1].relevance_score
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# expand_query tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,expected_expansion",
|
||||
[
|
||||
# AI/ML abbreviations
|
||||
("ML algorithms", "artificial intelligence"), # Not "ML" -> won't have AI
|
||||
("AI safety research", "artificial intelligence"),
|
||||
("NLP models for text", "natural language processing"),
|
||||
("deep learning vs DL", "deep learning"),
|
||||
# Rationality/EA terms
|
||||
("EA organizations", "effective altruism"),
|
||||
("AGI timeline predictions", "artificial general intelligence"),
|
||||
("x-risk reduction", "existential risk"),
|
||||
# Reverse mappings
|
||||
("machine learning basics", "ml"),
|
||||
("artificial intelligence ethics", "ai"),
|
||||
],
|
||||
)
|
||||
def test_expand_query_adds_synonyms(query, expected_expansion):
|
||||
"""Should expand abbreviations and add synonyms."""
|
||||
result = expand_query(query)
|
||||
assert expected_expansion in result.lower()
|
||||
assert query in result # Original query preserved
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query",
|
||||
[
|
||||
"hello world", # No expansions
|
||||
"python programming", # No expansions
|
||||
"database optimization", # No expansions
|
||||
],
|
||||
)
|
||||
def test_expand_query_no_match(query):
|
||||
"""Should return original query when no expansions match."""
|
||||
result = expand_query(query)
|
||||
assert result == query
|
||||
|
||||
|
||||
def test_expand_query_case_insensitive():
|
||||
"""Should match terms regardless of case."""
|
||||
assert "artificial intelligence" in expand_query("AI research").lower()
|
||||
assert "artificial intelligence" in expand_query("ai research").lower()
|
||||
assert "artificial intelligence" in expand_query("Ai Research").lower()
|
||||
|
||||
|
||||
def test_expand_query_word_boundaries():
|
||||
"""Should only match whole words, not partial matches."""
|
||||
# "mail" contains "ai" but shouldn't trigger expansion
|
||||
result = expand_query("email server")
|
||||
assert result == "email server"
|
||||
|
||||
# "claim" contains "ai" but shouldn't trigger expansion
|
||||
result = expand_query("insurance claim")
|
||||
assert result == "insurance claim"
|
||||
|
||||
|
||||
def test_expand_query_multiple_terms():
|
||||
"""Should expand multiple matching terms."""
|
||||
result = expand_query("AI and ML applications")
|
||||
assert "artificial intelligence" in result.lower()
|
||||
assert "machine learning" in result.lower()
|
||||
|
||||
|
||||
def test_expand_query_preserves_original():
|
||||
"""Should preserve original query text."""
|
||||
original = "AI safety research"
|
||||
result = expand_query(original)
|
||||
assert result.startswith(original)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user