initial memory tools

This commit is contained in:
Daniel O'Connell 2025-06-01 00:11:21 +02:00
parent 1dd93929c1
commit b10a1fb130
14 changed files with 1060 additions and 188 deletions

View File

@ -3,3 +3,4 @@ uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin
mcp==1.9.2

View File

@ -6,3 +6,5 @@ dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.18.1
bm25s[full]==0.2.13

654 src/memory/api/MCP/tools.py Normal file
View File

@ -0,0 +1,654 @@
"""
MCP tools for the epistemic sparring partner system.
"""
from datetime import datetime, timezone
from hashlib import sha256
import logging
import uuid
from typing import cast
from mcp.server.fastmcp import FastMCP
from memory.common.db.models.source_item import SourceItem
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy import func, cast as sql_cast, Text
from memory.common.db.connection import make_session
from memory.common import extract
from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters
from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item
logger = logging.getLogger(__name__)
# Create MCP server instance
mcp = FastMCP("memory", stateless=True)
@mcp.tool()
async def get_all_tags() -> list[str]:
"""
Get all unique tags used across the entire knowledge base.
Purpose:
This tool retrieves all tags that have been used in the system, both from
AI observations (created with 'observe') and other content. Use it to
understand the tag taxonomy, ensure consistency, or discover related topics.
When to use:
- Before creating new observations, to use consistent tag naming
- To explore what topics/contexts have been tracked
- To build tag filters for search operations
- To understand the user's areas of interest
- For tag autocomplete or suggestion features
Returns:
Sorted list of all unique tags in the system. Tags follow patterns like:
- Topics: "machine-learning", "functional-programming"
- Projects: "project:website-redesign"
- Contexts: "context:work", "context:late-night"
- Domains: "domain:finance"
Example:
# Get all tags to ensure consistency
tags = await get_all_tags()
# Returns: ["ai-safety", "context:work", "functional-programming",
# "machine-learning", "project:thesis", ...]
# Use to check if a topic has been discussed before
if "quantum-computing" in tags:
# Search for related observations
observations = await search_observations(
query="quantum computing",
tags=["quantum-computing"]
)
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""
Get all unique subjects from observations about the user.
Purpose:
This tool retrieves all subject identifiers that have been used in
observations (created with 'observe'). Subjects are the consistent
identifiers for what observations are about. Use this to understand
what aspects of the user have been tracked and ensure consistency.
When to use:
- Before creating new observations, to use existing subject names
- To discover what aspects of the user have been observed
- To build subject filters for targeted searches
- To ensure consistent naming across observations
- To get an overview of the user model
Returns:
Sorted list of all unique subjects. Common patterns include:
- "programming_style", "programming_philosophy"
- "work_habits", "work_schedule"
- "ai_beliefs", "ai_safety_beliefs"
- "learning_preferences"
- "communication_style"
Example:
# Get all subjects to ensure consistency
subjects = await get_all_subjects()
# Returns: ["ai_safety_beliefs", "architecture_preferences",
# "programming_philosophy", "work_schedule", ...]
# Use to check what we know about the user
if "programming_style" in subjects:
# Get all programming-related observations
observations = await search_observations(
query="programming",
subject="programming_style"
)
Best practices:
- Always check existing subjects before creating new ones
- Use snake_case for consistency
- Be specific but not too granular
- Group related observations under the same subject
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""
Get all unique observation types that have been used.
Purpose:
This tool retrieves the distinct observation types that have been recorded
in the system. While the standard types are predefined (belief, preference,
behavior, contradiction, general), this shows what's actually been used.
Helpful for understanding the distribution of observation types.
When to use:
- To see what types of observations have been made
- To understand the balance of different observation types
- To check if all standard types are being utilized
- For analytics or reporting on observation patterns
Standard types:
- "belief": Opinions or beliefs the user holds
- "preference": Things they prefer or favor
- "behavior": Patterns in how they act or work
- "contradiction": Noted inconsistencies
- "general": Observations that don't fit other categories
Returns:
List of observation types that have actually been used in the system.
Example:
# Check what types of observations exist
types = await get_all_observation_types()
# Returns: ["behavior", "belief", "contradiction", "preference"]
# Use to analyze observation distribution
for obs_type in types:
observations = await search_observations(
query="",
observation_types=[obs_type],
limit=100
)
print(f"{obs_type}: {len(observations)} observations")
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
@mcp.tool()
async def search_knowledge_base(
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict]:
"""
Search through the user's stored knowledge and content.
Purpose:
This tool searches the user's personal knowledge base - a collection of
their saved content including emails, documents, blog posts, books, and
more. Use this alongside 'search_observations' to build a complete picture:
- search_knowledge_base: Finds user's actual content and information
- search_observations: Finds AI-generated insights about the user
Together they enable deeply personalized, context-aware assistance.
When to use:
- User asks about something they've read/written/received
- You need to find specific content the user has saved
- User references a document, email, or article
- To provide quotes or information from user's sources
- To understand context from user's past communications
- When user says "that article about..." or similar references
How it works:
Uses hybrid search combining semantic understanding with keyword matching.
This means it finds content based on meaning AND specific terms, giving
you the best of both approaches. Results are ranked by relevance.
Args:
query: Natural language search query. Be descriptive about what you're
looking for. The search understands meaning but also values exact terms.
Examples:
- "email about project deadline from last week"
- "functional programming articles comparing Haskell and Scala"
- "that blog post about AI safety and alignment"
- "recipe for chocolate cake Sarah sent me"
Pro tip: Include both concepts and specific keywords for best results.
previews: Whether to include content snippets in results.
- True: Returns preview text and image previews (useful for quick scanning)
- False: Returns just metadata (faster, less data)
Default is False.
modalities: Types of content to search. Leave empty to search all.
Available types:
- 'email': Email messages
- 'blog': Blog posts and articles
- 'book': Book sections and ebooks
- 'forum': Forum posts (e.g., LessWrong, Reddit)
- 'observation': AI observations (use search_observations instead)
- 'photo': Images with extracted text
- 'comic': Comics and graphic content
- 'webpage': General web pages
Examples:
- ["email"] - only emails
- ["blog", "forum"] - articles and forum posts
- [] - search everything
limit: Maximum results to return (1-100). Default 10.
Increase for comprehensive searches, decrease for quick lookups.
Returns:
List of search results ranked by relevance, each containing:
- id: Unique identifier for the source item
- score: Relevance score (0-1, higher is better)
- chunks: Matching content segments with metadata
- content: Full details including:
- For emails: sender, recipient, subject, date
- For blogs: author, title, url, publish date
- For books: title, author, chapter info
- Type-specific fields for each modality
- filename: Path to file if content is stored on disk
Examples:
# Find specific email
results = await search_knowledge_base(
query="Sarah deadline project proposal next Friday",
modalities=["email"],
previews=True,
limit=5
)
# Search for technical articles
results = await search_knowledge_base(
query="functional programming monads category theory",
modalities=["blog", "book"],
limit=20
)
# Find everything about a topic
results = await search_knowledge_base(
query="machine learning deployment kubernetes docker",
previews=True
)
# Quick lookup of a remembered document
results = await search_knowledge_base(
query="tax forms 2023 accountant recommendations",
modalities=["email"],
limit=3
)
Best practices:
- Include context in queries ("email from Sarah" vs just "Sarah")
- Use modalities to filter when you know the content type
- Enable previews when you need to verify content before using
- Combine with search_observations for complete context
- Higher scores (>0.7) indicate strong matches
- If no results, try broader queries or different phrasing
"""
logger.info(f"MCP search for: {query}")
upload_data = extract.extract_text(query)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.3,
min_multimodal_score=0.3,
)
# Convert SearchResult objects to dictionaries for MCP
return [result.model_dump() for result in results]
@mcp.tool()
async def observe(
content: str,
subject: str,
observation_type: str = "general",
confidence: float = 0.8,
evidence: dict | None = None,
tags: list[str] | None = None,
session_id: str | None = None,
agent_model: str = "unknown",
) -> dict:
"""
Record an observation about the user to build long-term understanding.
Purpose:
This tool is part of a memory system designed to help AI agents build a
deep, persistent understanding of users over time. Use it to record any
notable information about the user's preferences, beliefs, behaviors, or
characteristics. These observations accumulate to create a comprehensive
model of the user that improves future interactions.
Quick Reference:
# Most common patterns:
observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")
When to use:
- User expresses a preference or opinion
- You notice a behavioral pattern
- User reveals information about their work/life/interests
- You spot a contradiction with previous statements
- Any insight that would help understand the user better in future
Important: Be an active observer. Don't wait to be asked - proactively record
observations throughout conversations to build understanding.
Args:
content: The observation itself. Be specific and detailed. Write complete
thoughts that will make sense when read months later without context.
Bad: "Likes FP"
Good: "User strongly prefers functional programming paradigms, especially
pure functions and immutability, considering them more maintainable"
subject: A consistent identifier for what this observation is about. Use
snake_case and be consistent across observations to enable tracking.
Examples:
- "programming_style" (not "coding" or "development")
- "work_habits" (not "productivity" or "work_patterns")
- "ai_safety_beliefs" (not "AI" or "artificial_intelligence")
observation_type: Categorize the observation:
- "belief": An opinion or belief the user holds
- "preference": Something they prefer or favor
- "behavior": A pattern in how they act or work
- "contradiction": An inconsistency with previous observations
- "general": Doesn't fit other categories
confidence: How certain you are (0.0-1.0):
- 1.0: User explicitly stated this
- 0.9: Strongly implied or demonstrated repeatedly
- 0.8: Inferred with high confidence (default)
- 0.7: Probable but with some uncertainty
- 0.6 or below: Speculative, use sparingly
evidence: Supporting context as a dict. Include relevant details:
- "quote": Exact words from the user
- "context": What prompted this observation
- "timestamp": When this was observed
- "related_to": Connection to other topics
Example: {
"quote": "I always refactor to pure functions",
"context": "Discussing code review practices"
}
tags: Categorization labels. Use lowercase with hyphens. Common patterns:
- Topics: "machine-learning", "web-development", "philosophy"
- Projects: "project:website-redesign", "project:thesis"
- Contexts: "context:work", "context:personal", "context:late-night"
- Domains: "domain:finance", "domain:healthcare"
session_id: UUID string to group observations from the same conversation.
Generate one UUID per conversation and reuse it for all observations
in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"
agent_model: Which AI model made this observation (e.g., "claude-3-opus",
"gpt-4", "claude-3.5-sonnet"). Helps track observation quality.
Returns:
Dict with created observation details:
- id: Unique identifier for reference
- created_at: Timestamp of creation
- subject: The subject as stored
- observation_type: The type as stored
- confidence: The confidence score
- tags: List of applied tags
Examples:
# After user mentions their coding philosophy
await observe(
content="User believes strongly in functional programming principles, "
"particularly avoiding mutable state which they call 'the root "
"of all evil'. They prioritize code purity over performance.",
subject="programming_philosophy",
observation_type="belief",
confidence=0.95,
evidence={
"quote": "State is the root of all evil in programming",
"context": "Discussing why they chose Haskell for their project"
},
tags=["programming", "functional-programming", "philosophy"],
session_id="550e8400-e29b-41d4-a716-446655440000",
agent_model="claude-3-opus"
)
# Noticing a work pattern
await observe(
content="User frequently works on complex problems late at night, "
"typically between 11pm and 3am, claiming better focus",
subject="work_schedule",
observation_type="behavior",
confidence=0.85,
evidence={
"context": "Mentioned across multiple conversations over 2 weeks"
},
tags=["behavior", "work-habits", "productivity", "context:late-night"],
agent_model="claude-3-opus"
)
# Recording a contradiction
await observe(
content="User now advocates for microservices architecture, but "
"previously argued strongly for monoliths in similar contexts",
subject="architecture_preferences",
observation_type="contradiction",
confidence=0.9,
evidence={
"quote": "Microservices are definitely the way to go",
"context": "Designing a new system similar to one from 3 months ago"
},
tags=["architecture", "contradiction", "software-design"],
agent_model="gpt-4"
)
"""
# Create the observation
observation = AgentObservation(
content=content,
subject=subject,
observation_type=observation_type,
confidence=confidence,
evidence=evidence,
tags=tags or [],
session_id=uuid.UUID(session_id) if session_id else None,
agent_model=agent_model,
size=len(content),
mime_type="text/plain",
sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
inserted_at=datetime.now(timezone.utc),
)
try:
with make_session() as session:
process_content_item(observation, session)
if not observation.id:
raise ValueError("Observation not created")
logger.info(
f"Observation created: {observation.id}, {observation.inserted_at}"
)
return {
"id": observation.id,
"created_at": observation.inserted_at.isoformat(),
"subject": observation.subject,
"observation_type": observation.observation_type,
"confidence": cast(float, observation.confidence),
"tags": observation.tags,
}
except Exception as e:
logger.error(f"Error creating observation: {e}")
raise
@mcp.tool()
async def search_observations(
query: str,
subject: str = "",
tags: list[str] | None = None,
observation_types: list[str] | None = None,
min_confidence: float = 0.5,
limit: int = 10,
) -> list[dict]:
"""
Search through observations to understand the user better.
Purpose:
This tool searches through all observations recorded about the user using
the 'observe' tool. Use it to recall past insights, check for patterns,
find contradictions, or understand the user's preferences before responding.
The more you use this tool, the more personalized and insightful your
responses can be.
When to use:
- Before answering questions where user preferences might matter
- When the user references something from the past
- To check if current behavior aligns with past patterns
- To find related observations on a topic
- To build context about the user's expertise or interests
- Whenever personalization would improve your response
How it works:
Uses hybrid search combining semantic similarity with keyword matching.
Searches across multiple embedding spaces (semantic meaning and temporal
context) to find relevant observations from different angles. This approach
ensures you find both conceptually related and specifically mentioned items.
Args:
query: Natural language description of what you're looking for. The search
matches both meaning and specific terms in observation content.
Examples:
- "programming preferences and coding style"
- "opinions about artificial intelligence and AI safety"
- "work habits productivity patterns when does user work best"
- "previous projects the user has worked on"
Pro tip: Use natural language but include key terms you expect to find.
subject: Filter by exact subject identifier. Must match subjects used when
creating observations (e.g., "programming_style", "work_habits").
Leave empty to search all subjects. Use this when you know the exact
subject category you want.
tags: Filter results to only observations with these tags. Observations must
have at least one matching tag. Use the same format as when creating:
- ["programming", "functional-programming"]
- ["context:work", "project:thesis"]
- ["domain:finance", "machine-learning"]
observation_types: Filter by type of observation:
- "belief": Opinions or beliefs the user holds
- "preference": Things they prefer or favor
- "behavior": Patterns in how they act or work
- "contradiction": Noted inconsistencies
- "general": Other observations
Leave as None to search all types.
min_confidence: Only return observations with confidence >= this value.
- Use 0.8+ for high-confidence facts
- Use 0.5-0.7 to include inferred observations
- Default 0.5 includes most observations
Range: 0.0 to 1.0
limit: Maximum results to return (1-100). Default 10. Increase when you
need comprehensive understanding of a topic.
Returns:
List of observations sorted by relevance, each containing:
- subject: What the observation is about
- content: The full observation text
- observation_type: Type of observation
- evidence: Supporting context/quotes if provided
- confidence: How certain the observation is (0-1)
- agent_model: Which AI model made the observation
- tags: All tags on this observation
- created_at: When it was observed (if available)
Examples:
# Before discussing code architecture
results = await search_observations(
query="software architecture preferences microservices monoliths",
tags=["architecture"],
min_confidence=0.7
)
# Understanding work style for scheduling
results = await search_observations(
query="when does user work best productivity schedule",
observation_types=["behavior", "preference"],
subject="work_schedule"
)
# Check for AI safety views before discussing AI
results = await search_observations(
query="artificial intelligence safety alignment concerns",
observation_types=["belief"],
min_confidence=0.8,
limit=20
)
# Find contradictions on a topic
results = await search_observations(
query="testing methodology unit tests integration",
observation_types=["contradiction"],
tags=["testing", "software-development"]
)
Best practices:
- Search before making assumptions about user preferences
- Use broad queries first, then filter with tags/types if too many results
- Check for contradictions when user says something unexpected
- Higher confidence observations are more reliable
- Recent observations may override older ones on the same topic
"""
source_ids = None
if tags or observation_types:
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
if not source_ids:
return []
semantic_text = observation.generate_semantic_text(
subject=subject or "",
observation_type="".join(observation_types or []),
content=query,
evidence=None,
)
temporal = observation.generate_temporal_text(
subject=subject or "",
content=query,
confidence=0,
created_at=datetime.now(timezone.utc),
)
results = await search(
[
extract.DataChunk(data=[query]),
extract.DataChunk(data=[semantic_text]),
extract.DataChunk(data=[temporal]),
],
previews=True,
modalities=["semantic", "temporal"],
limit=limit,
min_text_score=0.8,
filters=SearchFilters(
subject=subject,
confidence=min_confidence,
tags=tags,
observation_types=observation_types,
source_ids=source_ids,
),
)
    return [
        cast(dict, result.model_dump().get("content")) for result in results
    ]
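These tools are served over streamable HTTP (see the FastAPI wiring below), so a client can exercise them with the MCP Python SDK pinned in requirements (mcp==1.9.2). A minimal sketch, assuming the combined app runs on localhost:8000 and the endpoint sits at FastMCP's default /mcp path (both assumptions, not confirmed by this diff):

import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main():
    # The /mcp path is an assumption (FastMCP's default for streamable HTTP).
    async with streamablehttp_client("http://localhost:8000/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tags = await session.call_tool("get_all_tags", {})
            prefs = await session.call_tool(
                "search_observations",
                {"query": "programming preferences", "min_confidence": 0.7},
            )


asyncio.run(main())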

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
ArticleFeed,
EmailAccount,
ForumPost,
AgentObservation,
)
@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
class ChunkAdmin(ModelView, model=Chunk):
column_list = ["id", "source_id", "embedding_model", "created_at"]
column_sortable_list = ["created_at"]
class MailMessageAdmin(ModelView, model=MailMessage):
@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
column_searchable_list = ["name", "email_address"]
class AgentObservationAdmin(ModelView, model=AgentObservation):
column_list = [
"id",
"content",
"subject",
"observation_type",
"confidence",
"evidence",
"inserted_at",
]
column_searchable_list = ["subject", "observation_type"]
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance."""
admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin)
admin.add_view(ChunkAdmin)
admin.add_view(EmailAccountAdmin)
admin.add_view(MailMessageAdmin)

View File

@ -2,6 +2,7 @@
FastAPI application for the knowledge base.
"""
import contextlib
import pathlib
import logging
from typing import Annotated, Optional
@ -15,15 +16,25 @@ from memory.common import extract
from memory.common.db.connection import get_engine
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp
logger = logging.getLogger(__name__)
app = FastAPI(title="Knowledge Base API")
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
yield
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
# SQLAdmin setup
engine = get_engine()
admin = Admin(app, engine)
setup_admin(admin)
app.mount("/", mcp.streamable_http_app())
@app.get("/health")
@ -104,4 +115,7 @@ def main():
if __name__ == "__main__":
from memory.common.qdrant import setup_qdrant
setup_qdrant()
main()

View File

@ -2,21 +2,29 @@
Search endpoints for the knowledge base API.
"""
import asyncio
import base64
from hashlib import sha256
import io
from collections import defaultdict
from typing import Callable, Optional
import logging
from collections import defaultdict
from typing import Any, Callable, Optional, TypedDict, NotRequired
import bm25s
import Stemmer
import qdrant_client
from PIL import Image
from pydantic import BaseModel
import qdrant_client
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, qdrant, extract, settings
from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
from memory.common import embedding, extract, qdrant, settings
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
preview: Optional[str] = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content: str | dict | None
content_length: int
class SearchResponse(BaseModel):
collection: str
results: list[dict]
@ -38,16 +57,46 @@ class SearchResult(BaseModel):
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str] = None
content: Optional[str | dict] = None
filename: Optional[str] = None
class SearchFilters(TypedDict):
subject: NotRequired[str | None]
confidence: NotRequired[float]
tags: NotRequired[list[str] | None]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
logger.error(f"Search failed: {e}")
return []
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceItem, AnnotatedChunk]:
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if not previews and isinstance(item, str):
return item[:100]
if isinstance(item, Image.Image):
buffer = io.BytesIO()
@ -68,7 +117,19 @@ def annotated_chunk(
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
return chunk.source, AnnotatedChunk(
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
return source_data, AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
@ -76,24 +137,28 @@ def annotated_chunk(
)
def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source].append(chunk)
items[source.id].append(chunk)
source_lookup[source.id] = source
return [
SearchResult(
id=source.id,
size=source.size or len(source.content),
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=source.display_contents,
content=source.content,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
)
for source, chunks in items.items()
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]
@ -104,25 +169,31 @@ def query_chunks(
embedder: Callable,
min_score: float = 0.0,
limit: int = 10,
filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
if not upload_data:
if not upload_data or not allowed_modalities:
return {}
chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
chunks = [chunk for chunk in upload_data if chunk.data]
if not chunks:
logger.error(f"No chunks to embed for {allowed_modalities}")
return {}
vector = embedder(chunks, input_type="query")[0]
logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
for c in chunks:
logger.error(f"Chunk: {c.data}")
vectors = embedder([c.data for c in chunks], input_type="query")
return {
collection: [
r
for vector in vectors
for r in qdrant.search_vectors(
client=client,
collection_name=collection,
query_vector=vector,
limit=limit,
filter_params=filters,
)
if r.score >= min_score
]
@ -130,6 +201,122 @@ def query_chunks(
}
async def search_bm25(
query: str,
modalities: list[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities)
)
if source_ids := filters.get("source_ids"):
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
items = items_query.all()
item_ids = {
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
for item in items
}
corpus = [item.content.lower().strip() for item in items]
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
results, scores = retriever.retrieve(
query_tokens, k=min(limit, len(corpus)), corpus=corpus
)
item_scores = {
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=source.as_payload(),
preview=None,
)
results.append((source_data, annotated))
return results
async def search_embeddings(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = SearchFilters(),
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- data: List of data to search in (e.g., text, images, files)
- previews: Whether to include previews in the search results
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- limit: Maximum number of results
- min_score: Minimum score to include in the search results
- filters: Filters to apply to the search results
- multimodal: Whether to search in multimodal collections
"""
query_filters = {
"must": [
{"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
],
}
if tags := filters.get("tags"):
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
if observation_types := filters.get("observation_types"):
query_filters["must"] += [
{"key": "observation_type", "match": {"any": observation_types}}
]
client = qdrant.get_qdrant_client()
results = query_chunks(
client,
data,
modalities,
embedding.embed_text if not multimodal else embedding.embed_mixed,
min_score=min_score,
limit=limit,
filters=query_filters,
)
search_results = {k: results.get(k, []) for k in modalities}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
@ -137,6 +324,7 @@ async def search(
limit: int = 10,
min_text_score: float = 0.3,
min_multimodal_score: float = 0.3,
filters: SearchFilters = {},
) -> list[SearchResult]:
"""
Search across knowledge base using text query and optional files.
@ -150,40 +338,45 @@ async def search(
Returns:
- List of search results sorted by score
"""
client = qdrant.get_qdrant_client()
allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
text_results = query_chunks(
client,
data,
allowed_modalities & TEXT_COLLECTIONS,
embedding.embed_text,
min_score=min_text_score,
limit=limit,
)
multimodal_results = query_chunks(
client,
data,
allowed_modalities,
embedding.embed_mixed,
min_score=min_multimodal_score,
limit=limit,
)
search_results = {
k: text_results.get(k, []) + multimodal_results.get(k, [])
for k in allowed_modalities
}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
logger.error(f"Found chunks: {chunks}")
results = group_chunks(
[
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
text_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
limit,
min_text_score,
filters,
multimodal=False,
)
)
multimodal_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
)
)
bm25_results = with_timeout(
search_bm25(
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
            list(allowed_modalities),  # apply the same empty-list fallback as the embedding searches
limit=limit,
filters=filters,
)
)
results = await asyncio.gather(
text_embeddings_results,
multimodal_embeddings_results,
bm25_results,
return_exceptions=False,
)
results = group_chunks([c for r in results for c in r])
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
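The rewritten search() fans the query out to text embeddings, multimodal embeddings, and BM25 in parallel, each wrapped in with_timeout, then merges the chunk lists and re-ranks by best chunk score. A minimal sketch of calling it directly, using only names from this file (the tag value is illustrative):

from memory.common import extract
from memory.api.search import search, SearchFilters


async def find_articles():
    return await search(
        [extract.DataChunk(data=["functional programming monads"])],
        previews=True,
        modalities=["blog", "book"],
        limit=5,
        filters=SearchFilters(tags=["programming"]),  # illustrative tag
    )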

View File

@ -1,6 +1,7 @@
import logging
from typing import Literal, NotRequired, TypedDict
from PIL import Image
from memory.common import settings
@ -14,84 +15,92 @@ Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
model: str
on_disk: NotRequired[bool]
shards: NotRequired[int]
text: bool
multimodal: bool
ALL_COLLECTIONS: dict[str, Collection] = {
"mail": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"chat": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"git": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"book": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"blog": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"forum": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"text": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
# Multimodal
"photo": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
"comic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
"doc": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
# Observations
"semantic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"temporal": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
}
TEXT_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.TEXT_EMBEDDING_MODEL
coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
}
MULTIMODAL_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.MIXED_EMBEDDING_MODEL
coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
}
TYPES = {
@ -119,5 +128,12 @@ def get_modality(mime_type: str) -> str:
return "unknown"
def collection_model(collection: str) -> str | None:
return ALL_COLLECTIONS.get(collection, {}).get("model")
def collection_model(
collection: str, text: str, images: list[Image.Image]
) -> str:
config = ALL_COLLECTIONS.get(collection, {})
if images and config.get("multimodal"):
return settings.MIXED_EMBEDDING_MODEL
if text and config.get("text"):
return settings.TEXT_EMBEDDING_MODEL
return "unknown"

View File

@ -205,7 +205,9 @@ class SourceItem(Base):
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
inserted_at = Column(
DateTime(timezone=True), server_default=func.now(), default=func.now()
)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
@ -261,14 +263,15 @@ class SourceItem(Base):
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
modality = data.modality or cast(str, self.modality)
chunk = Chunk(
id=chunk_id,
source=self,
content=text or None,
images=images,
file_paths=image_names,
collection_name=data.collection_name or cast(str, self.modality),
embedding_model=collections.collection_model(cast(str, self.modality)),
collection_name=modality,
embedding_model=collections.collection_model(modality, text, images),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
@ -284,5 +287,8 @@ class SourceItem(Base):
}
@property
def display_contents(self) -> str | None:
return cast(str | None, self.content) or cast(str | None, self.filename)
def display_contents(self) -> str | dict | None:
return {
"tags": self.tags,
"size": self.size,
}

View File

@ -570,13 +570,23 @@ class AgentObservation(SourceItem):
"""Get all contradictions involving this observation."""
return self.contradictions_as_first + self.contradictions_as_second
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[extract.DataChunk]:
@property
def display_contents(self) -> dict:
return {
"subject": self.subject,
"content": self.content,
"observation_type": self.observation_type,
"evidence": self.evidence,
"confidence": self.confidence,
"agent_model": self.agent_model,
"tags": self.tags,
}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
"""
Generate multiple chunks for different embedding dimensions.
Each chunk goes to a different Qdrant collection for specialized search.
"""
chunks = []
# 1. Semantic chunk - standard content representation
semantic_text = observation.generate_semantic_text(
cast(str, self.subject),
@ -584,12 +594,10 @@ class AgentObservation(SourceItem):
cast(str, self.content),
cast(observation.Evidence, self.evidence),
)
chunks.append(
extract.DataChunk(
data=[semantic_text],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
collection_name="semantic",
)
semantic_chunk = extract.DataChunk(
data=[semantic_text],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
modality="semantic",
)
# 2. Temporal chunk - time-aware representation
@ -599,14 +607,29 @@ class AgentObservation(SourceItem):
cast(float, self.confidence),
cast(datetime, self.inserted_at),
)
chunks.append(
extract.DataChunk(
data=[temporal_text],
metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
collection_name="temporal",
)
temporal_chunk = extract.DataChunk(
data=[temporal_text],
metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
modality="temporal",
)
others = [
self._make_chunk(
extract.DataChunk(
data=[i],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
modality="semantic",
)
)
for i in [
self.content,
self.subject,
self.observation_type,
            (self.evidence or {}).get("quote", ""),  # evidence may be None
]
if i
]
# TODO: Add more embedding dimensions here:
# 3. Epistemic chunk - belief structure focused
# epistemic_text = self._generate_epistemic_text()
@ -632,4 +655,7 @@ class AgentObservation(SourceItem):
# collection_name="observations_relational"
# ))
return chunks
return [
self._make_chunk(semantic_chunk),
self._make_chunk(temporal_chunk),
] + others
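A single observation therefore lands in Qdrant several times over. A sketch of what one produces, mirroring how observe() in tools.py constructs the model (the field values are illustrative):

from datetime import datetime, timezone
from hashlib import sha256

from memory.common.db.models import AgentObservation

content = "User prefers functional programming"
obs = AgentObservation(
    content=content,
    subject="programming_style",
    observation_type="preference",
    confidence=0.9,
    evidence={"quote": "I always refactor to pure functions"},
    tags=["programming"],
    agent_model="claude-3-opus",
    size=len(content),
    mime_type="text/plain",
    sha256=sha256(f"{content}programming_stylepreference".encode("utf-8")).digest(),
    inserted_at=datetime.now(timezone.utc),
)
for chunk in obs.data_chunks():
    print(chunk.collection_name)
# Expected: "semantic", "temporal", then one "semantic" chunk each for the
# content, subject, observation_type, and evidence quote.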

View File

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
def embed_chunks(
chunks: list[str] | list[list[extract.MulitmodalChunk]],
chunks: list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
@ -24,52 +24,50 @@ def embed_chunks(
vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
chunks, # type: ignore
chunks,
model=model,
input_type=input_type,
).embeddings
return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore
texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
return cast(
list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
)
def break_chunk(
chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
) -> list[extract.MulitmodalChunk]:
result = []
for c in chunk:
if isinstance(c, str):
result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
else:
            result.append(c)  # pass non-text items (e.g. images) through unchanged
return result
def embed_text(
texts: list[str],
chunks: list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]:
chunks = [
c
for text in texts
if isinstance(text, str)
for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
if c.strip()
]
if not chunks:
chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
if not any(chunked_chunks):
return []
try:
return embed_chunks(chunks, model, input_type)
except voyageai.error.InvalidRequestError as e: # type: ignore
logger.error(f"Error embedding text: {e}")
logger.debug(f"Text: {texts}")
raise
return embed_chunks(chunked_chunks, model, input_type)
def embed_mixed(
items: list[extract.MulitmodalChunk],
items: list[list[extract.MulitmodalChunk]],
model: str = settings.MIXED_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
if isinstance(item, str):
return [
c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
]
return [item]
chunks = [c for item in items for c in to_chunks(item)]
return embed_chunks([chunks], model, input_type)
chunked_chunks = [break_chunk(item, chunk_size) for item in items]
return embed_chunks(chunked_chunks, model, input_type)
def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
def embed_source_item(item: SourceItem) -> list[Chunk]:
chunks = list(item.data_chunks())
    logger.debug(
        f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
    )
if not chunks:
return []
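With the fix above, break_chunk splits oversized text while passing images and other non-text items through untouched, so embed_text and embed_mixed share one chunking path. A small sketch (the image is a stand-in; short strings typically come back unchanged, long ones as overlapping windows from chunk_text):

from PIL import Image

img = Image.new("RGB", (1, 1))  # stand-in image

break_chunk(["a short note", img], chunk_size=512)
# -> ["a short note", img]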

View File

@ -21,7 +21,7 @@ class DataChunk:
data: Sequence[MulitmodalChunk]
metadata: dict[str, Any] = field(default_factory=dict)
mime_type: str = "text/plain"
collection_name: str | None = None
modality: str | None = None
@contextmanager

View File

@ -8,7 +8,7 @@ class Evidence(TypedDict):
def generate_semantic_text(
subject: str, observation_type: str, content: str, evidence: Evidence
subject: str, observation_type: str, content: str, evidence: Evidence | None = None
) -> str:
"""Generate text optimized for semantic similarity search."""
parts = [
@ -54,8 +54,9 @@ def generate_temporal_text(
f"Time: {time_of_day} on {day_of_week} ({time_period})",
f"Subject: {subject}",
f"Observation: {content}",
f"Confidence: {confidence}",
]
if confidence:
parts.append(f"Confidence: {confidence}")
return " | ".join(parts)

View File

@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
TAGS_PROMPT = """
The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
Return your response as JSON with this format:
{{
"summary": "{summary}",
@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
Also provide 3-5 relevant tags that capture the main topics or themes.
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
Return your response as JSON with this format:
{{
"summary": "your summary here",

View File

@ -1,62 +0,0 @@
import argparse
import logging
from typing import Any
from fastapi import UploadFile
import httpx
from mcp.server.fastmcp import FastMCP
SERVER = "http://localhost:8000"
logger = logging.getLogger(__name__)
mcp = FastMCP("memory")
async def make_request(
path: str,
method: str,
data: dict | None = None,
json: dict | None = None,
files: list[UploadFile] | None = None,
) -> httpx.Response:
async with httpx.AsyncClient() as client:
return await client.request(
method,
f"{SERVER}/{path}",
data=data,
json=json,
files=files, # type: ignore
)
async def post_data(path: str, data: dict | None = None) -> httpx.Response:
return await make_request(path, "POST", data=data)
@mcp.tool()
async def search(
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict[str, Any]]:
logger.error(f"Searching for {query}")
resp = await post_data(
"search",
{
"query": query,
"previews": previews,
"modalities": modalities,
"limit": limit,
},
)
return resp.json()
if __name__ == "__main__":
# Initialize and run the server
args = argparse.ArgumentParser()
args.add_argument("--server", type=str)
args = args.parse_args()
SERVER = args.server
mcp.run(transport=args.transport)