From b10a1fb13039c17531607784ee1978bd6c461ac5 Mon Sep 17 00:00:00 2001
From: Daniel O'Connell
Date: Sun, 1 Jun 2025 00:11:21 +0200
Subject: [PATCH] initial memory tools

---
 requirements-api.txt                        |   1 +
 requirements-common.txt                     |   4 +-
 src/memory/api/MCP/tools.py                 | 654 ++++++++++++++++++++
 src/memory/api/admin.py                     |  16 +
 src/memory/api/app.py                       |  16 +-
 src/memory/api/search.py                    | 293 +++++++--
 src/memory/common/collections.py            |  60 +-
 src/memory/common/db/models/source_item.py  |  16 +-
 src/memory/common/db/models/source_items.py |  58 +-
 src/memory/common/embedding.py              |  57 +-
 src/memory/common/extract.py                |   2 +-
 src/memory/common/formatters/observation.py |   5 +-
 src/memory/common/summarizer.py             |   4 +
 src/memory/mcp/server.py                    |  62 --
 14 files changed, 1060 insertions(+), 188 deletions(-)
 create mode 100644 src/memory/api/MCP/tools.py
 delete mode 100644 src/memory/mcp/server.py

diff --git a/requirements-api.txt b/requirements-api.txt
index 5e44c4f..fa6732f 100644
--- a/requirements-api.txt
+++ b/requirements-api.txt
@@ -3,3 +3,4 @@ uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin
+mcp==1.9.2
\ No newline at end of file
diff --git a/requirements-common.txt b/requirements-common.txt
index c65d20c..dbf0da5 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -5,4 +5,6 @@ alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
-anthropic==0.18.1
\ No newline at end of file
+anthropic==0.18.1
+
+bm25s[full]==0.2.13
\ No newline at end of file
diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py
new file mode 100644
index 0000000..77ba76b
--- /dev/null
+++ b/src/memory/api/MCP/tools.py
@@ -0,0 +1,654 @@
+"""
+MCP tools for the epistemic sparring partner system.
+"""
+
+import logging
+import uuid
+from datetime import datetime, timezone
+from hashlib import sha256
+from typing import cast
+
+from mcp.server.fastmcp import FastMCP
+from sqlalchemy import Text, cast as sql_cast, func
+from sqlalchemy.dialects.postgresql import ARRAY
+
+from memory.api.search import SearchFilters, search
+from memory.common import extract
+from memory.common.db.connection import make_session
+from memory.common.db.models import AgentObservation
+from memory.common.db.models.source_item import SourceItem
+from memory.common.formatters import observation
+from memory.workers.tasks.content_processing import process_content_item
+
+logger = logging.getLogger(__name__)
+
+# Create MCP server instance
+mcp = FastMCP("memory", stateless_http=True)
+
+
+@mcp.tool()
+async def get_all_tags() -> list[str]:
+    """
+    Get all unique tags used across the entire knowledge base.
+
+    Purpose:
+        This tool retrieves all tags that have been used in the system, both from
+        AI observations (created with 'observe') and other content. Use it to
+        understand the tag taxonomy, ensure consistency, or discover related topics.
+
+    When to use:
+        - Before creating new observations, to use consistent tag naming
+        - To explore what topics/contexts have been tracked
+        - To build tag filters for search operations
+        - To understand the user's areas of interest
+        - For tag autocomplete or suggestion features
+
+    Returns:
+        Sorted list of all unique tags in the system. Tags follow patterns like:
+        - Topics: "machine-learning", "functional-programming"
+        - Projects: "project:website-redesign"
+        - Contexts: "context:work", "context:late-night"
+        - Domains: "domain:finance"
+
+    Example:
+        # Get all tags to ensure consistency
+        tags = await get_all_tags()
+        # Returns: ["ai-safety", "context:work", "functional-programming",
+        #           "machine-learning", "project:thesis", ...]
+
+        # Use to check if a topic has been discussed before
+        if "quantum-computing" in tags:
+            # Search for related observations
+            observations = await search_observations(
+                query="quantum computing",
+                tags=["quantum-computing"]
+            )
+    """
+    with make_session() as session:
+        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
+        return sorted({row[0] for row in tags_query if row[0] is not None})
+
+
+@mcp.tool()
+async def get_all_subjects() -> list[str]:
+    """
+    Get all unique subjects from observations about the user.
+
+    Purpose:
+        This tool retrieves all subject identifiers that have been used in
+        observations (created with 'observe'). Subjects are the consistent
+        identifiers for what observations are about. Use this to understand
+        what aspects of the user have been tracked and ensure consistency.
+
+    When to use:
+        - Before creating new observations, to use existing subject names
+        - To discover what aspects of the user have been observed
+        - To build subject filters for targeted searches
+        - To ensure consistent naming across observations
+        - To get an overview of the user model
+
+    Returns:
+        Sorted list of all unique subjects. Common patterns include:
+        - "programming_style", "programming_philosophy"
+        - "work_habits", "work_schedule"
+        - "ai_beliefs", "ai_safety_beliefs"
+        - "learning_preferences"
+        - "communication_style"
+
+    Example:
+        # Get all subjects to ensure consistency
+        subjects = await get_all_subjects()
+        # Returns: ["ai_safety_beliefs", "architecture_preferences",
+        #           "programming_philosophy", "work_schedule", ...]
+
+        # Use to check what we know about the user
+        if "programming_style" in subjects:
+            # Get all programming-related observations
+            observations = await search_observations(
+                query="programming",
+                subject="programming_style"
+            )
+
+    Best practices:
+        - Always check existing subjects before creating new ones
+        - Use snake_case for consistency
+        - Be specific but not too granular
+        - Group related observations under the same subject
+    """
+    with make_session() as session:
+        return sorted(
+            r.subject for r in session.query(AgentObservation.subject).distinct()
+        )
+
+
+@mcp.tool()
+async def get_all_observation_types() -> list[str]:
+    """
+    Get all unique observation types that have been used.
+
+    Purpose:
+        This tool retrieves the distinct observation types that have been recorded
+        in the system. While the standard types are predefined (belief, preference,
+        behavior, contradiction, general), this shows what's actually been used.
+        Helpful for understanding the distribution of observation types.
+
+    When to use:
+        - To see what types of observations have been made
+        - To understand the balance of different observation types
+        - To check if all standard types are being utilized
+        - For analytics or reporting on observation patterns
+
+    Standard types:
+        - "belief": Opinions or beliefs the user holds
+        - "preference": Things they prefer or favor
+        - "behavior": Patterns in how they act or work
+        - "contradiction": Noted inconsistencies
+        - "general": Observations that don't fit other categories
+
+    Returns:
+        List of observation types that have actually been used in the system.
+
+    Example:
+        # Check what types of observations exist
+        types = await get_all_observation_types()
+        # Returns: ["behavior", "belief", "contradiction", "preference"]
+
+        # Use to analyze observation distribution
+        for obs_type in types:
+            observations = await search_observations(
+                query="",
+                observation_types=[obs_type],
+                limit=100
+            )
+            print(f"{obs_type}: {len(observations)} observations")
+    """
+    with make_session() as session:
+        return sorted(
+            {
+                r.observation_type
+                for r in session.query(AgentObservation.observation_type).distinct()
+                if r.observation_type is not None
+            }
+        )
+
+
+@mcp.tool()
+async def search_knowledge_base(
+    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
+) -> list[dict]:
+    """
+    Search through the user's stored knowledge and content.
+
+    Purpose:
+        This tool searches the user's personal knowledge base - a collection of
+        their saved content including emails, documents, blog posts, books, and
+        more. Use this alongside 'search_observations' to build a complete picture:
+        - search_knowledge_base: Finds user's actual content and information
+        - search_observations: Finds AI-generated insights about the user
+        Together they enable deeply personalized, context-aware assistance.
+
+    When to use:
+        - User asks about something they've read/written/received
+        - You need to find specific content the user has saved
+        - User references a document, email, or article
+        - To provide quotes or information from user's sources
+        - To understand context from user's past communications
+        - When user says "that article about..." or similar references
+
+    How it works:
+        Uses hybrid search combining semantic understanding with keyword matching.
+        This means it finds content based on meaning AND specific terms, giving
+        you the best of both approaches. Results are ranked by relevance.
+
+    Args:
+        query: Natural language search query. Be descriptive about what you're
+            looking for. The search understands meaning but also values exact terms.
+            Examples:
+            - "email about project deadline from last week"
+            - "functional programming articles comparing Haskell and Scala"
+            - "that blog post about AI safety and alignment"
+            - "recipe for chocolate cake Sarah sent me"
+            Pro tip: Include both concepts and specific keywords for best results.
+
+        previews: Whether to include content snippets in results.
+            - True: Returns preview text and image previews (useful for quick scanning)
+            - False: Returns just metadata (faster, less data)
+            Default is False.
+
+        modalities: Types of content to search. Leave empty to search all.
+            Available types:
+            - 'email': Email messages
+            - 'blog': Blog posts and articles
+            - 'book': Book sections and ebooks
+            - 'forum': Forum posts (e.g., LessWrong, Reddit)
+            - 'observation': AI observations (use search_observations instead)
+            - 'photo': Images with extracted text
+            - 'comic': Comics and graphic content
+            - 'webpage': General web pages
+            Examples:
+            - ["email"] - only emails
+            - ["blog", "forum"] - articles and forum posts
+            - [] - search everything
+
+        limit: Maximum results to return (1-100). Default 10.
+            Increase for comprehensive searches, decrease for quick lookups.
+
+    Returns:
+        List of search results ranked by relevance, each containing:
+        - id: Unique identifier for the source item
+        - score: Relevance score (0-1, higher is better)
+        - chunks: Matching content segments with metadata
+        - content: Full details including:
+            - For emails: sender, recipient, subject, date
+            - For blogs: author, title, url, publish date
+            - For books: title, author, chapter info
+            - Type-specific fields for each modality
+        - filename: Path to file if content is stored on disk
+
+    Examples:
+        # Find specific email
+        results = await search_knowledge_base(
+            query="Sarah deadline project proposal next Friday",
+            modalities=["email"],
+            previews=True,
+            limit=5
+        )
+
+        # Search for technical articles
+        results = await search_knowledge_base(
+            query="functional programming monads category theory",
+            modalities=["blog", "book"],
+            limit=20
+        )
+
+        # Find everything about a topic
+        results = await search_knowledge_base(
+            query="machine learning deployment kubernetes docker",
+            previews=True
+        )
+
+        # Quick lookup of a remembered document
+        results = await search_knowledge_base(
+            query="tax forms 2023 accountant recommendations",
+            modalities=["email"],
+            limit=3
+        )
+
+    Best practices:
+        - Include context in queries ("email from Sarah" vs just "Sarah")
+        - Use modalities to filter when you know the content type
+        - Enable previews when you need to verify content before using
+        - Combine with search_observations for complete context
+        - Higher scores (>0.7) indicate strong matches
+        - If no results, try broader queries or different phrasing
+    """
+    logger.info(f"MCP search for: {query}")
+
+    upload_data = extract.extract_text(query)
+    results = await search(
+        upload_data,
+        previews=previews,
+        modalities=modalities,
+        limit=limit,
+        min_text_score=0.3,
+        min_multimodal_score=0.3,
+    )
+
+    # Convert SearchResult objects to dictionaries for MCP
+    return [result.model_dump() for result in results]
+
+
+@mcp.tool()
+async def observe(
+    content: str,
+    subject: str,
+    observation_type: str = "general",
+    confidence: float = 0.8,
+    evidence: dict | None = None,
+    tags: list[str] | None = None,
+    session_id: str | None = None,
+    agent_model: str = "unknown",
+) -> dict:
+    """
+    Record an observation about the user to build long-term understanding.
+
+    Purpose:
+        This tool is part of a memory system designed to help AI agents build a
+        deep, persistent understanding of users over time. Use it to record any
+        notable information about the user's preferences, beliefs, behaviors, or
+        characteristics. These observations accumulate to create a comprehensive
+        model of the user that improves future interactions.
+
+    Quick Reference:
+        # Most common patterns:
+        observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
+        observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
+        observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
+        observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")
+
+    When to use:
+        - User expresses a preference or opinion
+        - You notice a behavioral pattern
+        - User reveals information about their work/life/interests
+        - You spot a contradiction with previous statements
+        - Any insight that would help understand the user better in the future
+
+    Important: Be an active observer. Don't wait to be asked - proactively record
+    observations throughout conversations to build understanding.
+
+    Args:
+        content: The observation itself. Be specific and detailed. Write complete
+            thoughts that will make sense when read months later without context.
+            Bad: "Likes FP"
+            Good: "User strongly prefers functional programming paradigms, especially
+            pure functions and immutability, considering them more maintainable"
+
+        subject: A consistent identifier for what this observation is about. Use
+            snake_case and be consistent across observations to enable tracking.
+            Examples:
+            - "programming_style" (not "coding" or "development")
+            - "work_habits" (not "productivity" or "work_patterns")
+            - "ai_safety_beliefs" (not "AI" or "artificial_intelligence")
+
+        observation_type: Categorize the observation:
+            - "belief": An opinion or belief the user holds
+            - "preference": Something they prefer or favor
+            - "behavior": A pattern in how they act or work
+            - "contradiction": An inconsistency with previous observations
+            - "general": Doesn't fit other categories
+
+        confidence: How certain you are (0.0-1.0):
+            - 1.0: User explicitly stated this
+            - 0.9: Strongly implied or demonstrated repeatedly
+            - 0.8: Inferred with high confidence (default)
+            - 0.7: Probable but with some uncertainty
+            - 0.6 or below: Speculative, use sparingly
+
+        evidence: Supporting context as a dict. Include relevant details:
+            - "quote": Exact words from the user
+            - "context": What prompted this observation
+            - "timestamp": When this was observed
+            - "related_to": Connection to other topics
+            Example: {
+                "quote": "I always refactor to pure functions",
+                "context": "Discussing code review practices"
+            }
+
+        tags: Categorization labels. Use lowercase with hyphens. Common patterns:
+            - Topics: "machine-learning", "web-development", "philosophy"
+            - Projects: "project:website-redesign", "project:thesis"
+            - Contexts: "context:work", "context:personal", "context:late-night"
+            - Domains: "domain:finance", "domain:healthcare"
+
+        session_id: UUID string to group observations from the same conversation.
+            Generate one UUID per conversation and reuse it for all observations
+            in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"
+
+        agent_model: Which AI model made this observation (e.g., "claude-3-opus",
+            "gpt-4", "claude-3.5-sonnet"). Helps track observation quality.
+
+    Returns:
+        Dict with created observation details:
+        - id: Unique identifier for reference
+        - created_at: Timestamp of creation
+        - subject: The subject as stored
+        - observation_type: The type as stored
+        - confidence: The confidence score
+        - tags: List of applied tags
+
+    Examples:
+        # After user mentions their coding philosophy
+        await observe(
+            content="User believes strongly in functional programming principles, "
+                    "particularly avoiding mutable state which they call 'the root "
+                    "of all evil'. They prioritize code purity over performance.",
+            subject="programming_philosophy",
+            observation_type="belief",
+            confidence=0.95,
+            evidence={
+                "quote": "State is the root of all evil in programming",
+                "context": "Discussing why they chose Haskell for their project"
+            },
+            tags=["programming", "functional-programming", "philosophy"],
+            session_id="550e8400-e29b-41d4-a716-446655440000",
+            agent_model="claude-3-opus"
+        )
+
+        # Noticing a work pattern
+        await observe(
+            content="User frequently works on complex problems late at night, "
+                    "typically between 11pm and 3am, claiming better focus",
+            subject="work_schedule",
+            observation_type="behavior",
+            confidence=0.85,
+            evidence={
+                "context": "Mentioned across multiple conversations over 2 weeks"
+            },
+            tags=["behavior", "work-habits", "productivity", "context:late-night"],
+            agent_model="claude-3-opus"
+        )
+
+        # Recording a contradiction
+        await observe(
+            content="User now advocates for microservices architecture, but "
+                    "previously argued strongly for monoliths in similar contexts",
+            subject="architecture_preferences",
+            observation_type="contradiction",
+            confidence=0.9,
+            evidence={
+                "quote": "Microservices are definitely the way to go",
+                "context": "Designing a new system similar to one from 3 months ago"
+            },
+            tags=["architecture", "contradiction", "software-design"],
+            agent_model="gpt-4"
+        )
+    """
+    # Create the observation (local name avoids shadowing the imported
+    # `observation` formatters module)
+    agent_observation = AgentObservation(
+        content=content,
+        subject=subject,
+        observation_type=observation_type,
+        confidence=confidence,
+        evidence=evidence,
+        tags=tags or [],
+        session_id=uuid.UUID(session_id) if session_id else None,
+        agent_model=agent_model,
+        size=len(content),
+        mime_type="text/plain",
+        sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
+        inserted_at=datetime.now(timezone.utc),
+    )
+    try:
+        with make_session() as session:
+            process_content_item(agent_observation, session)
+
+            if not agent_observation.id:
+                raise ValueError("Observation not created")
+
+            logger.info(
+                f"Observation created: {agent_observation.id}, {agent_observation.inserted_at}"
+            )
+            return {
+                "id": agent_observation.id,
+                "created_at": agent_observation.inserted_at.isoformat(),
+                "subject": agent_observation.subject,
+                "observation_type": agent_observation.observation_type,
+                "confidence": cast(float, agent_observation.confidence),
+                "tags": agent_observation.tags,
+            }
+
+    except Exception as e:
+        logger.error(f"Error creating observation: {e}")
+        raise
+
+
+@mcp.tool()
+async def search_observations(
+    query: str,
+    subject: str = "",
+    tags: list[str] | None = None,
+    observation_types: list[str] | None = None,
+    min_confidence: float = 0.5,
+    limit: int = 10,
+) -> list[dict]:
+    """
+    Search through observations to understand the user better.
+
+    Purpose:
+        This tool searches through all observations recorded about the user using
+        the 'observe' tool. Use it to recall past insights, check for patterns,
+        find contradictions, or understand the user's preferences before responding.
+        The more you use this tool, the more personalized and insightful your
+        responses can be.
+
+    When to use:
+        - Before answering questions where user preferences might matter
+        - When the user references something from the past
+        - To check if current behavior aligns with past patterns
+        - To find related observations on a topic
+        - To build context about the user's expertise or interests
+        - Whenever personalization would improve your response
+
+    How it works:
+        Uses hybrid search combining semantic similarity with keyword matching.
+        Searches across multiple embedding spaces (semantic meaning and temporal
+        context) to find relevant observations from different angles. This approach
+        ensures you find both conceptually related and specifically mentioned items.
+
+    Args:
+        query: Natural language description of what you're looking for. The search
+            matches both meaning and specific terms in observation content.
+            Examples:
+            - "programming preferences and coding style"
+            - "opinions about artificial intelligence and AI safety"
+            - "work habits productivity patterns when does user work best"
+            - "previous projects the user has worked on"
+            Pro tip: Use natural language but include key terms you expect to find.
+
+        subject: Filter by exact subject identifier. Must match subjects used when
+            creating observations (e.g., "programming_style", "work_habits").
+            Leave empty to search all subjects. Use this when you know the exact
+            subject category you want.
+
+        tags: Filter results to only observations with these tags. Observations must
+            have at least one matching tag. Use the same format as when creating:
+            - ["programming", "functional-programming"]
+            - ["context:work", "project:thesis"]
+            - ["domain:finance", "machine-learning"]
+
+        observation_types: Filter by type of observation:
+            - "belief": Opinions or beliefs the user holds
+            - "preference": Things they prefer or favor
+            - "behavior": Patterns in how they act or work
+            - "contradiction": Noted inconsistencies
+            - "general": Other observations
+            Leave as None to search all types.
+
+        min_confidence: Only return observations with confidence >= this value.
+            - Use 0.8+ for high-confidence facts
+            - Use 0.5-0.7 to include inferred observations
+            - Default 0.5 includes most observations
+            Range: 0.0 to 1.0
+
+        limit: Maximum results to return (1-100). Default 10. Increase when you
+            need comprehensive understanding of a topic.
+
+    Returns:
+        List of observations sorted by relevance, each containing:
+        - subject: What the observation is about
+        - content: The full observation text
+        - observation_type: Type of observation
+        - evidence: Supporting context/quotes if provided
+        - confidence: How certain the observation is (0-1)
+        - agent_model: Which AI model made the observation
+        - tags: All tags on this observation
+        - created_at: When it was observed (if available)
+
+    Examples:
+        # Before discussing code architecture
+        results = await search_observations(
+            query="software architecture preferences microservices monoliths",
+            tags=["architecture"],
+            min_confidence=0.7
+        )
+
+        # Understanding work style for scheduling
+        results = await search_observations(
+            query="when does user work best productivity schedule",
+            observation_types=["behavior", "preference"],
+            subject="work_schedule"
+        )
+
+        # Check for AI safety views before discussing AI
+        results = await search_observations(
+            query="artificial intelligence safety alignment concerns",
+            observation_types=["belief"],
+            min_confidence=0.8,
+            limit=20
+        )
+
+        # Find contradictions on a topic
+        results = await search_observations(
+            query="testing methodology unit tests integration",
+            observation_types=["contradiction"],
+            tags=["testing", "software-development"]
+        )
+
+    Best practices:
+        - Search before making assumptions about user preferences
+        - Use broad queries first, then filter with tags/types if too many results
+        - Check for contradictions when user says something unexpected
+        - Higher confidence observations are more reliable
+        - Recent observations may override older ones on the same topic
+    """
+    source_ids = None
+    if tags or observation_types:
+        with make_session() as session:
+            items_query = session.query(AgentObservation.id)
+
+            if tags:
+                # Use PostgreSQL array overlap operator with proper array casting
+                items_query = items_query.filter(
+                    AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
+                )
+            if observation_types:
+                items_query = items_query.filter(
+                    AgentObservation.observation_type.in_(observation_types)
+                )
+            source_ids = [item.id for item in items_query.all()]
+            if not source_ids:
+                return []
+
+    semantic_text = observation.generate_semantic_text(
+        subject=subject or "",
+        observation_type=" ".join(observation_types or []),
+        content=query,
+        evidence=None,
+    )
+    temporal = observation.generate_temporal_text(
+        subject=subject or "",
+        content=query,
+        confidence=0,
+        created_at=datetime.now(timezone.utc),
+    )
+    results = await search(
+        [
+            extract.DataChunk(data=[query]),
+            extract.DataChunk(data=[semantic_text]),
+            extract.DataChunk(data=[temporal]),
+        ],
+        previews=True,
+        modalities=["semantic", "temporal"],
+        limit=limit,
+        min_text_score=0.8,
+        filters=SearchFilters(
+            subject=subject,
+            confidence=min_confidence,
+            tags=tags,
+            observation_types=observation_types,
+            source_ids=source_ids,
+        ),
+    )
+
+    return [cast(dict, result.model_dump()["content"]) for result in results]
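Not part of the patch, but a minimal smoke-test sketch of the new tools may help reviewers: the MCP tools above are plain async functions, so they can be exercised in-process. This assumes a configured Postgres and Qdrant behind `make_session`/`qdrant`; names and values are illustrative only.

```python
# Illustrative only: exercises observe() and search_observations() directly,
# bypassing the MCP transport layer.
import asyncio
import uuid

from memory.api.MCP.tools import observe, search_observations


async def main():
    # One UUID per conversation, reused across all observations in it
    session_id = str(uuid.uuid4())
    created = await observe(
        content="User prefers immutable data structures in day-to-day coding",
        subject="programming_style",
        observation_type="preference",
        confidence=0.9,
        tags=["programming", "functional-programming"],
        session_id=session_id,
        agent_model="claude-3-opus",
    )
    print(created["id"], created["tags"])

    hits = await search_observations(
        query="immutable data structures",
        subject="programming_style",
    )
    print(hits)


asyncio.run(main())
```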
+ + Returns: + List of observations sorted by relevance, each containing: + - subject: What the observation is about + - content: The full observation text + - observation_type: Type of observation + - evidence: Supporting context/quotes if provided + - confidence: How certain the observation is (0-1) + - agent_model: Which AI model made the observation + - tags: All tags on this observation + - created_at: When it was observed (if available) + + Examples: + # Before discussing code architecture + results = await search_observations( + query="software architecture preferences microservices monoliths", + tags=["architecture"], + min_confidence=0.7 + ) + + # Understanding work style for scheduling + results = await search_observations( + query="when does user work best productivity schedule", + observation_types=["behavior", "preference"], + subject="work_schedule" + ) + + # Check for AI safety views before discussing AI + results = await search_observations( + query="artificial intelligence safety alignment concerns", + observation_types=["belief"], + min_confidence=0.8, + limit=20 + ) + + # Find contradictions on a topic + results = await search_observations( + query="testing methodology unit tests integration", + observation_types=["contradiction"], + tags=["testing", "software-development"] + ) + + Best practices: + - Search before making assumptions about user preferences + - Use broad queries first, then filter with tags/types if too many results + - Check for contradictions when user says something unexpected + - Higher confidence observations are more reliable + - Recent observations may override older ones on same topic + """ + source_ids = None + if tags or observation_types: + with make_session() as session: + items_query = session.query(AgentObservation.id) + + if tags: + # Use PostgreSQL array overlap operator with proper array casting + items_query = items_query.filter( + AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), + ) + if observation_types: + items_query = items_query.filter( + AgentObservation.observation_type.in_(observation_types) + ) + source_ids = [item.id for item in items_query.all()] + if not source_ids: + return [] + + semantic_text = observation.generate_semantic_text( + subject=subject or "", + observation_type="".join(observation_types or []), + content=query, + evidence=None, + ) + temporal = observation.generate_temporal_text( + subject=subject or "", + content=query, + confidence=0, + created_at=datetime.now(timezone.utc), + ) + results = await search( + [ + extract.DataChunk(data=[query]), + extract.DataChunk(data=[semantic_text]), + extract.DataChunk(data=[temporal]), + ], + previews=True, + modalities=["semantic", "temporal"], + limit=limit, + min_text_score=0.8, + filters=SearchFilters( + subject=subject, + confidence=min_confidence, + tags=tags, + observation_types=observation_types, + source_ids=source_ids, + ), + ) + + return [ + cast(dict, cast(dict, result.model_dump()).get("content")) for result in results + ] diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index fa4139e..5997bea 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -18,6 +18,7 @@ from memory.common.db.models import ( ArticleFeed, EmailAccount, ForumPost, + AgentObservation, ) @@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem): class ChunkAdmin(ModelView, model=Chunk): column_list = ["id", "source_id", "embedding_model", "created_at"] + column_sortable_list = ["created_at"] class MailMessageAdmin(ModelView, 
model=MailMessage): @@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount): column_searchable_list = ["name", "email_address"] +class AgentObservationAdmin(ModelView, model=AgentObservation): + column_list = [ + "id", + "content", + "subject", + "observation_type", + "confidence", + "evidence", + "inserted_at", + ] + column_searchable_list = ["subject", "observation_type"] + + def setup_admin(admin: Admin): """Add all admin views to the admin instance.""" admin.add_view(SourceItemAdmin) + admin.add_view(AgentObservationAdmin) admin.add_view(ChunkAdmin) admin.add_view(EmailAccountAdmin) admin.add_view(MailMessageAdmin) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 1b81edf..4f37c39 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -2,6 +2,7 @@ FastAPI application for the knowledge base. """ +import contextlib import pathlib import logging from typing import Annotated, Optional @@ -15,15 +16,25 @@ from memory.common import extract from memory.common.db.connection import get_engine from memory.api.admin import setup_admin from memory.api.search import search, SearchResult +from memory.api.MCP.tools import mcp logger = logging.getLogger(__name__) -app = FastAPI(title="Knowledge Base API") + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + async with contextlib.AsyncExitStack() as stack: + await stack.enter_async_context(mcp.session_manager.run()) + yield + + +app = FastAPI(title="Knowledge Base API", lifespan=lifespan) # SQLAdmin setup engine = get_engine() admin = Admin(app, engine) setup_admin(admin) +app.mount("/", mcp.streamable_http_app()) @app.get("/health") @@ -104,4 +115,7 @@ def main(): if __name__ == "__main__": + from memory.common.qdrant import setup_qdrant + + setup_qdrant() main() diff --git a/src/memory/api/search.py b/src/memory/api/search.py index 9195ff2..6b4b501 100644 --- a/src/memory/api/search.py +++ b/src/memory/api/search.py @@ -2,21 +2,29 @@ Search endpoints for the knowledge base API. 
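For reference, a hedged sketch of how a client would reach the mounted MCP server over streamable HTTP. The exact URL depends on FastMCP's configured streamable-HTTP path (commonly `/mcp`) combined with the mount point above, and the client helpers shown are from the `mcp` SDK pinned in this patch; treat both as assumptions to verify.

```python
# Sketch: connect to the MCP endpoint exposed by app.py and list tools.
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main():
    # Assumed URL: server from this patch running locally on port 8000
    async with streamablehttp_client("http://localhost:8000/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])


asyncio.run(main())
```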
""" +import asyncio import base64 +from hashlib import sha256 import io -from collections import defaultdict -from typing import Callable, Optional import logging +from collections import defaultdict +from typing import Any, Callable, Optional, TypedDict, NotRequired +import bm25s +import Stemmer +import qdrant_client from PIL import Image from pydantic import BaseModel -import qdrant_client from qdrant_client.http import models as qdrant_models -from memory.common import embedding, qdrant, extract, settings -from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS +from memory.common import embedding, extract, qdrant, settings +from memory.common.collections import ( + ALL_COLLECTIONS, + MULTIMODAL_COLLECTIONS, + TEXT_COLLECTIONS, +) from memory.common.db.connection import make_session -from memory.common.db.models import Chunk, SourceItem +from memory.common.db.models import Chunk logger = logging.getLogger(__name__) @@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel): preview: Optional[str | None] = None +class SourceData(BaseModel): + """Holds source item data to avoid SQLAlchemy session issues""" + + id: int + size: int | None + mime_type: str | None + filename: str | None + content: str | dict | None + content_length: int + + class SearchResponse(BaseModel): collection: str results: list[dict] @@ -38,16 +57,46 @@ class SearchResult(BaseModel): size: int mime_type: str chunks: list[AnnotatedChunk] - content: Optional[str] = None + content: Optional[str | dict] = None filename: Optional[str] = None +class SearchFilters(TypedDict): + subject: NotRequired[str | None] + confidence: NotRequired[float] + tags: NotRequired[list[str] | None] + observation_types: NotRequired[list[str] | None] + source_ids: NotRequired[list[int] | None] + + +async def with_timeout( + call, timeout: int = 2 +) -> list[tuple[SourceData, AnnotatedChunk]]: + """ + Run a function with a timeout. 
+ + Args: + call: The function to run + timeout: The timeout in seconds + """ + try: + return await asyncio.wait_for(call, timeout=timeout) + except TimeoutError: + logger.warning(f"Search timed out after {timeout}s") + return [] + except Exception as e: + logger.error(f"Search failed: {e}") + return [] + + def annotated_chunk( chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool -) -> tuple[SourceItem, AnnotatedChunk]: +) -> tuple[SourceData, AnnotatedChunk]: def serialize_item(item: bytes | str | Image.Image) -> str | None: if not previews and not isinstance(item, str): return None + if not previews and isinstance(item, str): + return item[:100] if isinstance(item, Image.Image): buffer = io.BytesIO() @@ -68,7 +117,19 @@ def annotated_chunk( for k, v in metadata.items() if k not in ["content", "filename", "size", "content_type", "tags"] } - return chunk.source, AnnotatedChunk( + + # Prefetch all needed source data while in session + source = chunk.source + source_data = SourceData( + id=source.id, + size=source.size, + mime_type=source.mime_type, + filename=source.filename, + content=source.display_contents, + content_length=len(source.content) if source.content else 0, + ) + + return source_data, AnnotatedChunk( id=str(chunk.id), score=search_result.score, metadata=metadata, @@ -76,24 +137,28 @@ def annotated_chunk( ) -def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]: +def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]: items = defaultdict(list) + source_lookup = {} + for source, chunk in chunks: - items[source].append(chunk) + items[source.id].append(chunk) + source_lookup[source.id] = source return [ SearchResult( id=source.id, - size=source.size or len(source.content), + size=source.size or source.content_length, mime_type=source.mime_type or "text/plain", filename=source.filename and source.filename.replace( str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files" ), - content=source.display_contents, + content=source.content, chunks=sorted(chunks, key=lambda x: x.score, reverse=True), ) - for source, chunks in items.items() + for source_id, chunks in items.items() + for source in [source_lookup[source_id]] ] @@ -104,25 +169,31 @@ def query_chunks( embedder: Callable, min_score: float = 0.0, limit: int = 10, + filters: dict[str, Any] | None = None, ) -> dict[str, list[qdrant_models.ScoredPoint]]: - if not upload_data: + if not upload_data or not allowed_modalities: return {} - chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data] + chunks = [chunk for chunk in upload_data if chunk.data] if not chunks: logger.error(f"No chunks to embed for {allowed_modalities}") return {} - vector = embedder(chunks, input_type="query")[0] + logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}") + for c in chunks: + logger.error(f"Chunk: {c.data}") + vectors = embedder([c.data for c in chunks], input_type="query") return { collection: [ r + for vector in vectors for r in qdrant.search_vectors( client=client, collection_name=collection, query_vector=vector, limit=limit, + filter_params=filters, ) if r.score >= min_score ] @@ -130,6 +201,122 @@ def query_chunks( } +async def search_bm25( + query: str, + modalities: list[str], + limit: int = 10, + filters: SearchFilters = SearchFilters(), +) -> list[tuple[SourceData, AnnotatedChunk]]: + with make_session() as db: + items_query = db.query(Chunk.id, Chunk.content).filter( + Chunk.collection_name.in_(modalities) + ) + if 
source_ids := filters.get("source_ids"): + items_query = items_query.filter(Chunk.source_id.in_(source_ids)) + items = items_query.all() + item_ids = { + sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id + for item in items + } + corpus = [item.content.lower().strip() for item in items] + + stemmer = Stemmer.Stemmer("english") + corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer) + retriever = bm25s.BM25() + retriever.index(corpus_tokens) + + query_tokens = bm25s.tokenize(query, stemmer=stemmer) + results, scores = retriever.retrieve( + query_tokens, k=min(limit, len(corpus)), corpus=corpus + ) + + item_scores = { + item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score + for doc, score in zip(results[0], scores[0]) + } + + with make_session() as db: + chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all() + results = [] + for chunk in chunks: + # Prefetch all needed source data while in session + source = chunk.source + source_data = SourceData( + id=source.id, + size=source.size, + mime_type=source.mime_type, + filename=source.filename, + content=source.display_contents, + content_length=len(source.content) if source.content else 0, + ) + + annotated = AnnotatedChunk( + id=str(chunk.id), + score=item_scores[chunk.id], + metadata=source.as_payload(), + preview=None, + ) + results.append((source_data, annotated)) + + return results + + +async def search_embeddings( + data: list[extract.DataChunk], + previews: Optional[bool] = False, + modalities: set[str] = set(), + limit: int = 10, + min_score: float = 0.3, + filters: SearchFilters = SearchFilters(), + multimodal: bool = False, +) -> list[tuple[SourceData, AnnotatedChunk]]: + """ + Search across knowledge base using text query and optional files. + + Parameters: + - data: List of data to search in (e.g., text, images, files) + - previews: Whether to include previews in the search results + - modalities: List of modalities to search in (e.g., "text", "photo", "doc") + - limit: Maximum number of results + - min_score: Minimum score to include in the search results + - filters: Filters to apply to the search results + - multimodal: Whether to search in multimodal collections + """ + query_filters = { + "must": [ + {"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}}, + ], + } + if tags := filters.get("tags"): + query_filters["must"] += [{"key": "tags", "match": {"any": tags}}] + if observation_types := filters.get("observation_types"): + query_filters["must"] += [ + {"key": "observation_type", "match": {"any": observation_types}} + ] + + client = qdrant.get_qdrant_client() + results = query_chunks( + client, + data, + modalities, + embedding.embed_text if not multimodal else embedding.embed_mixed, + min_score=min_score, + limit=limit, + filters=query_filters, + ) + search_results = {k: results.get(k, []) for k in modalities} + + found_chunks = { + str(r.id): r for results in search_results.values() for r in results + } + with make_session() as db: + chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all() + return [ + annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False) + for chunk in chunks + ] + + async def search( data: list[extract.DataChunk], previews: Optional[bool] = False, @@ -137,6 +324,7 @@ async def search( limit: int = 10, min_text_score: float = 0.3, min_multimodal_score: float = 0.3, + filters: SearchFilters = {}, ) -> list[SearchResult]: """ Search across knowledge base using text query and optional files. 
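A standalone sketch of the bm25s calls used by `search_bm25` above, showing the `(documents, scores)` return shape the code relies on when it hashes returned documents back to chunk ids. The corpus strings here are made up.

```python
# Sketch of the bm25s indexing/retrieval round-trip used in search_bm25.
import bm25s
import Stemmer

corpus = [
    "functional programming in haskell",
    "deploying services with kubernetes",
]
stemmer = Stemmer.Stemmer("english")

retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer))

docs, scores = retriever.retrieve(
    bm25s.tokenize("haskell monads", stemmer=stemmer), k=2, corpus=corpus
)
# docs and scores are shaped (n_queries, k); docs[0] holds the matched corpus
# strings, scores[0] their BM25 scores, which is why search_bm25 can hash
# docs[0][i] to recover the originating chunk id.
print(list(zip(docs[0], scores[0])))
```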
diff --git a/src/memory/common/collections.py b/src/memory/common/collections.py
index 76dffed..f2aee53 100644
--- a/src/memory/common/collections.py
+++ b/src/memory/common/collections.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Literal, NotRequired, TypedDict
 
+from PIL import Image
 from memory.common import settings
 
@@ -14,84 +15,92 @@ Vector = list[float]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    model: str
     on_disk: NotRequired[bool]
     shards: NotRequired[int]
+    text: bool
+    multimodal: bool
 
 
 ALL_COLLECTIONS: dict[str, Collection] = {
     "mail": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "chat": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "git": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "book": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "blog": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "forum": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "text": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
-    # Multimodal
     "photo": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     "comic": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
    },
     "doc": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     # Observations
     "semantic": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "temporal": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
 }
 
 TEXT_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
 }
 MULTIMODAL_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
 }
 
 TYPES = {
@@ -119,5 +128,12 @@ def get_modality(mime_type: str) -> str:
     return "unknown"
 
 
-def collection_model(collection: str) -> str | None:
-    return ALL_COLLECTIONS.get(collection, {}).get("model")
+def collection_model(
+    collection: str, text: str, images: list[Image.Image]
+) -> str | None:
+    config = ALL_COLLECTIONS.get(collection, {})
+    if images and config.get("multimodal"):
+        return settings.MIXED_EMBEDDING_MODEL
+    if text and config.get("text"):
+        return settings.TEXT_EMBEDDING_MODEL
+    return "unknown"
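A quick illustration of how the new `text`/`multimodal` flags drive model selection in `collection_model` (the two model names come from the `settings` module):

```python
# Images on a multimodal collection win; otherwise non-empty text on a text
# collection selects the text model; anything else falls through to "unknown".
from PIL import Image

from memory.common import collections, settings

img = Image.new("RGB", (1, 1))
assert collections.collection_model("photo", "", [img]) == settings.MIXED_EMBEDDING_MODEL
assert collections.collection_model("mail", "body text", []) == settings.TEXT_EMBEDDING_MODEL
assert collections.collection_model("mail", "", []) == "unknown"
```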
diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py
index b5656b4..56e9a44 100644
--- a/src/memory/common/db/models/source_item.py
+++ b/src/memory/common/db/models/source_item.py
@@ -205,7 +205,9 @@ class SourceItem(Base):
     id = Column(BigInteger, primary_key=True)
     modality = Column(Text, nullable=False)
     sha256 = Column(BYTEA, nullable=False, unique=True)
-    inserted_at = Column(DateTime(timezone=True), server_default=func.now())
+    inserted_at = Column(
+        DateTime(timezone=True), server_default=func.now(), default=func.now()
+    )
     tags = Column(ARRAY(Text), nullable=False, server_default="{}")
     size = Column(Integer)
     mime_type = Column(Text)
@@ -261,14 +263,15 @@ class SourceItem(Base):
         images = [c for c in data.data if isinstance(c, Image.Image)]
         image_names = image_filenames(chunk_id, images)
 
+        modality = data.modality or cast(str, self.modality)
         chunk = Chunk(
             id=chunk_id,
             source=self,
             content=text or None,
             images=images,
             file_paths=image_names,
-            collection_name=data.collection_name or cast(str, self.modality),
-            embedding_model=collections.collection_model(cast(str, self.modality)),
+            collection_name=modality,
+            embedding_model=collections.collection_model(modality, text, images),
             item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
         )
         return chunk
@@ -284,5 +287,8 @@ class SourceItem(Base):
         }
 
     @property
-    def display_contents(self) -> str | None:
-        return cast(str | None, self.content) or cast(str | None, self.filename)
+    def display_contents(self) -> str | dict | None:
+        return {
+            "tags": self.tags,
+            "size": self.size,
+        }
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index ceb7e7d..985a2a6 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -570,13 +570,23 @@ class AgentObservation(SourceItem):
         """Get all contradictions involving this observation."""
         return self.contradictions_as_first + self.contradictions_as_second
 
-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[extract.DataChunk]:
+    @property
+    def display_contents(self) -> dict:
+        return {
+            "subject": self.subject,
+            "content": self.content,
+            "observation_type": self.observation_type,
+            "evidence": self.evidence,
+            "confidence": self.confidence,
+            "agent_model": self.agent_model,
+            "tags": self.tags,
+        }
+
+    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
         """
         Generate multiple chunks for different embedding dimensions.
         Each chunk goes to a different Qdrant collection for specialized search.
         """
-        chunks = []
-
         # 1. Semantic chunk - standard content representation
         semantic_text = observation.generate_semantic_text(
             cast(str, self.subject),
@@ -584,12 +594,10 @@ class AgentObservation(SourceItem):
             cast(str, self.content),
             cast(observation.Evidence, self.evidence),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[semantic_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
-                collection_name="semantic",
-            )
+        semantic_chunk = extract.DataChunk(
+            data=[semantic_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+            modality="semantic",
         )
 
         # 2. Temporal chunk - time-aware representation
@@ -599,14 +607,29 @@ class AgentObservation(SourceItem):
             cast(float, self.confidence),
             cast(datetime, self.inserted_at),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[temporal_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
-                collection_name="temporal",
-            )
+        temporal_chunk = extract.DataChunk(
+            data=[temporal_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
+            modality="temporal",
         )
 
+        others = [
+            self._make_chunk(
+                extract.DataChunk(
+                    data=[i],
+                    metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+                    modality="semantic",
+                )
+            )
+            for i in [
+                self.content,
+                self.subject,
+                self.observation_type,
+                (self.evidence or {}).get("quote", ""),
+            ]
+            if i
+        ]
+
         # TODO: Add more embedding dimensions here:
         # 3. Epistemic chunk - belief structure focused
         # epistemic_text = self._generate_epistemic_text()
@@ -632,4 +655,7 @@ class AgentObservation(SourceItem):
         #     collection_name="observations_relational"
         # ))
 
-        return chunks
+        return [
+            self._make_chunk(semantic_chunk),
+            self._make_chunk(temporal_chunk),
+        ] + others
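To make the fan-out concrete, a hedged sketch of what `data_chunks` now yields for a single observation. Constructing an `AgentObservation` outside the normal ingestion path needs a few extra columns; the placeholder values below are illustrative only.

```python
# Sketch: one observation fans out into a semantic chunk, a temporal chunk,
# and one extra "semantic" chunk per non-empty raw fragment.
from memory.common.db.models import AgentObservation

obs = AgentObservation(
    content="User prefers pure functions",
    subject="programming_style",
    observation_type="preference",
    confidence=0.9,
    evidence={"quote": "I always refactor to pure functions"},
    sha256=b"\x00" * 32,  # placeholder; real items hash their content
    size=27,
    mime_type="text/plain",
)
chunks = obs.data_chunks()
# Expected: "semantic" + "temporal", then semantic chunks for content,
# subject, observation_type and the evidence quote (empty fragments skipped)
print([(c.collection_name, c.embedding_model) for c in chunks])
```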
diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py
index caddbf8..44ae4ea 100644
--- a/src/memory/common/embedding.py
+++ b/src/memory/common/embedding.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 def embed_chunks(
-    chunks: list[str] | list[list[extract.MulitmodalChunk]],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
@@ -24,52 +24,50 @@ def embed_chunks(
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks,  # type: ignore
+            chunks,
             model=model,
             input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore
+
+    texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
+    return cast(
+        list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
+    )
+
+
+def break_chunk(
+    chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
+) -> list[extract.MulitmodalChunk]:
+    result = []
+    for c in chunk:
+        if isinstance(c, str):
+            result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
+        else:
+            result.append(c)
+    return result
 
 
 def embed_text(
-    texts: list[str],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunks = [
-        c
-        for text in texts
-        if isinstance(text, str)
-        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
-        if c.strip()
-    ]
-    if not chunks:
+    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
+    if not any(chunked_chunks):
         return []
 
-    try:
-        return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:  # type: ignore
-        logger.error(f"Error embedding text: {e}")
-        logger.debug(f"Text: {texts}")
-        raise
+    return embed_chunks(chunked_chunks, model, input_type)
 
 
 def embed_mixed(
-    items: list[extract.MulitmodalChunk],
+    items: list[list[extract.MulitmodalChunk]],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
-        if isinstance(item, str):
-            return [
-                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
-            ]
-        return [item]
-
-    chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model, input_type)
+    chunked_chunks = [break_chunk(item, chunk_size) for item in items]
+    return embed_chunks(chunked_chunks, model, input_type)
 
 
 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
@@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
 
 def embed_source_item(item: SourceItem) -> list[Chunk]:
     chunks = list(item.data_chunks())
+    logger.debug(
+        f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
+    )
     if not chunks:
         return []
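A small sketch of `break_chunk`'s contract: strings get split into overlapping token windows via `chunk_text`, while non-string items (such as images) pass through in place. The text and sizes here are made up.

```python
# Sketch: mixed text/image chunks keep their order; only strings are split.
from PIL import Image

from memory.common.embedding import break_chunk

img = Image.new("RGB", (1, 1))
pieces = break_chunk(["a short sentence", img], chunk_size=512)
# Short text survives as a single piece, so pieces == ["a short sentence", img];
# text longer than chunk_size would yield several overlapping string pieces
# ahead of the image.
print(pieces)
```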
diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py
index 6d1ad0e..18880eb 100644
--- a/src/memory/common/extract.py
+++ b/src/memory/common/extract.py
@@ -21,7 +21,7 @@ class DataChunk:
     data: Sequence[MulitmodalChunk]
     metadata: dict[str, Any] = field(default_factory=dict)
     mime_type: str = "text/plain"
-    collection_name: str | None = None
+    modality: str | None = None
 
 
 @contextmanager
diff --git a/src/memory/common/formatters/observation.py b/src/memory/common/formatters/observation.py
index 6f2819c..f53485e 100644
--- a/src/memory/common/formatters/observation.py
+++ b/src/memory/common/formatters/observation.py
@@ -8,7 +8,7 @@ class Evidence(TypedDict):
 
 
 def generate_semantic_text(
-    subject: str, observation_type: str, content: str, evidence: Evidence
+    subject: str, observation_type: str, content: str, evidence: Evidence | None = None
 ) -> str:
     """Generate text optimized for semantic similarity search."""
     parts = [
@@ -54,8 +54,9 @@ def generate_temporal_text(
         f"Time: {time_of_day} on {day_of_week} ({time_period})",
         f"Subject: {subject}",
         f"Observation: {content}",
-        f"Confidence: {confidence}",
     ]
+    if confidence:
+        parts.append(f"Confidence: {confidence}")
     return " | ".join(parts)
 
 
diff --git a/src/memory/common/summarizer.py b/src/memory/common/summarizer.py
index 7fdcac7..8d6fff6 100644
--- a/src/memory/common/summarizer.py
+++ b/src/memory/common/summarizer.py
@@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
 TAGS_PROMPT = """
 The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
 
+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "{summary}",
@@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
 Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
 Also provide 3-5 relevant tags that capture the main topics or themes.
 
+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "your summary here",
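For orientation, the observation formatters above produce pipe-delimited strings; the exact time bucketing (time of day, time period) lives in helpers not shown in this diff, so only the call shape is illustrated here.

```python
# Sketch of the formatter call shapes used by tools.py and AgentObservation.
from datetime import datetime, timezone

from memory.common.formatters.observation import (
    generate_semantic_text,
    generate_temporal_text,
)

print(generate_semantic_text("work_schedule", "behavior", "Works best late at night", None))
print(
    generate_temporal_text(
        "work_schedule",
        "Works best late at night",
        confidence=0.85,
        created_at=datetime(2025, 6, 1, 23, 30, tzinfo=timezone.utc),
    )
)
# Confidence is appended only when truthy, so query-time calls that pass 0
# (as search_observations does) omit it from the generated text.
```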
diff --git a/src/memory/mcp/server.py b/src/memory/mcp/server.py
deleted file mode 100644
index 96ea9c2..0000000
--- a/src/memory/mcp/server.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import argparse
-import logging
-from typing import Any
-from fastapi import UploadFile
-import httpx
-from mcp.server.fastmcp import FastMCP
-
-SERVER = "http://localhost:8000"
-
-
-logger = logging.getLogger(__name__)
-mcp = FastMCP("memory")
-
-
-async def make_request(
-    path: str,
-    method: str,
-    data: dict | None = None,
-    json: dict | None = None,
-    files: list[UploadFile] | None = None,
-) -> httpx.Response:
-    async with httpx.AsyncClient() as client:
-        return await client.request(
-            method,
-            f"{SERVER}/{path}",
-            data=data,
-            json=json,
-            files=files,  # type: ignore
-        )
-
-
-async def post_data(path: str, data: dict | None = None) -> httpx.Response:
-    return await make_request(path, "POST", data=data)
-
-
-@mcp.tool()
-async def search(
-    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
-) -> list[dict[str, Any]]:
-    logger.error(f"Searching for {query}")
-    resp = await post_data(
-        "search",
-        {
-            "query": query,
-            "previews": previews,
-            "modalities": modalities,
-            "limit": limit,
-        },
-    )
-
-    return resp.json()
-
-
-if __name__ == "__main__":
-    # Initialize and run the server
-    args = argparse.ArgumentParser()
-    args.add_argument("--server", type=str)
-    args = args.parse_args()
-
-    SERVER = args.server
-
-    mcp.run(transport=args.transport)