Mirror of https://github.com/mruwnik/memory.git (synced 2025-06-08 13:24:41 +02:00)

Commit b10a1fb130: initial memory tools
Parent: 1dd93929c1
@@ -3,3 +3,4 @@ uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin
+mcp==1.9.2

@@ -5,4 +5,6 @@ alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.18.1
+anthropic==0.18.1
+bm25s[full]==0.2.13
src/memory/api/MCP/tools.py (new file, 654 lines)

@@ -0,0 +1,654 @@
"""
MCP tools for the epistemic sparring partner system.
"""

from datetime import datetime, timezone
from hashlib import sha256
import logging
import uuid
from typing import cast

from mcp.server.fastmcp import FastMCP
from memory.common.db.models.source_item import SourceItem
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy import func, cast as sql_cast, Text
from memory.common.db.connection import make_session

from memory.common import extract
from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters
from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item

logger = logging.getLogger(__name__)

# Create MCP server instance
mcp = FastMCP("memory", stateless=True)


@mcp.tool()
async def get_all_tags() -> list[str]:
    """
    Get all unique tags used across the entire knowledge base.

    Purpose:
        This tool retrieves all tags that have been used in the system, both from
        AI observations (created with 'observe') and other content. Use it to
        understand the tag taxonomy, ensure consistency, or discover related topics.

    When to use:
        - Before creating new observations, to use consistent tag naming
        - To explore what topics/contexts have been tracked
        - To build tag filters for search operations
        - To understand the user's areas of interest
        - For tag autocomplete or suggestion features

    Returns:
        Sorted list of all unique tags in the system. Tags follow patterns like:
        - Topics: "machine-learning", "functional-programming"
        - Projects: "project:website-redesign"
        - Contexts: "context:work", "context:late-night"
        - Domains: "domain:finance"

    Example:
        # Get all tags to ensure consistency
        tags = await get_all_tags()
        # Returns: ["ai-safety", "context:work", "functional-programming",
        #           "machine-learning", "project:thesis", ...]

        # Use to check if a topic has been discussed before
        if "quantum-computing" in tags:
            # Search for related observations
            observations = await search_observations(
                query="quantum computing",
                tags=["quantum-computing"]
            )
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})
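For reference, `func.unnest` leans on PostgreSQL's `unnest` to explode the `tags` array before `DISTINCT` is applied. A minimal stand-in for the dedup-and-sort step in plain Python (toy rows, not the repo's schema):

    # Toy stand-in for: SELECT DISTINCT unnest(source_item.tags) FROM source_item;
    rows = [["ai-safety", "context:work"], ["ai-safety", "project:thesis"], None]
    tags = sorted({tag for row in rows if row for tag in row})
    print(tags)  # ['ai-safety', 'context:work', 'project:thesis']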
@mcp.tool()
async def get_all_subjects() -> list[str]:
    """
    Get all unique subjects from observations about the user.

    Purpose:
        This tool retrieves all subject identifiers that have been used in
        observations (created with 'observe'). Subjects are the consistent
        identifiers for what observations are about. Use this to understand
        what aspects of the user have been tracked and ensure consistency.

    When to use:
        - Before creating new observations, to use existing subject names
        - To discover what aspects of the user have been observed
        - To build subject filters for targeted searches
        - To ensure consistent naming across observations
        - To get an overview of the user model

    Returns:
        Sorted list of all unique subjects. Common patterns include:
        - "programming_style", "programming_philosophy"
        - "work_habits", "work_schedule"
        - "ai_beliefs", "ai_safety_beliefs"
        - "learning_preferences"
        - "communication_style"

    Example:
        # Get all subjects to ensure consistency
        subjects = await get_all_subjects()
        # Returns: ["ai_safety_beliefs", "architecture_preferences",
        #           "programming_philosophy", "work_schedule", ...]

        # Use to check what we know about the user
        if "programming_style" in subjects:
            # Get all programming-related observations
            observations = await search_observations(
                query="programming",
                subject="programming_style"
            )

    Best practices:
        - Always check existing subjects before creating new ones
        - Use snake_case for consistency
        - Be specific but not too granular
        - Group related observations under the same subject
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )
@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """
    Get all unique observation types that have been used.

    Purpose:
        This tool retrieves the distinct observation types that have been recorded
        in the system. While the standard types are predefined (belief, preference,
        behavior, contradiction, general), this shows what's actually been used.
        Helpful for understanding the distribution of observation types.

    When to use:
        - To see what types of observations have been made
        - To understand the balance of different observation types
        - To check if all standard types are being utilized
        - For analytics or reporting on observation patterns

    Standard types:
        - "belief": Opinions or beliefs the user holds
        - "preference": Things they prefer or favor
        - "behavior": Patterns in how they act or work
        - "contradiction": Noted inconsistencies
        - "general": Observations that don't fit other categories

    Returns:
        List of observation types that have actually been used in the system.

    Example:
        # Check what types of observations exist
        types = await get_all_observation_types()
        # Returns: ["behavior", "belief", "contradiction", "preference"]

        # Use to analyze observation distribution
        for obs_type in types:
            observations = await search_observations(
                query="",
                observation_types=[obs_type],
                limit=100
            )
            print(f"{obs_type}: {len(observations)} observations")
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )
@mcp.tool()
async def search_knowledge_base(
    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict]:
    """
    Search through the user's stored knowledge and content.

    Purpose:
        This tool searches the user's personal knowledge base - a collection of
        their saved content including emails, documents, blog posts, books, and
        more. Use this alongside 'search_observations' to build a complete picture:
        - search_knowledge_base: Finds user's actual content and information
        - search_observations: Finds AI-generated insights about the user
        Together they enable deeply personalized, context-aware assistance.

    When to use:
        - User asks about something they've read/written/received
        - You need to find specific content the user has saved
        - User references a document, email, or article
        - To provide quotes or information from user's sources
        - To understand context from user's past communications
        - When user says "that article about..." or similar references

    How it works:
        Uses hybrid search combining semantic understanding with keyword matching.
        This means it finds content based on meaning AND specific terms, giving
        you the best of both approaches. Results are ranked by relevance.

    Args:
        query: Natural language search query. Be descriptive about what you're
            looking for. The search understands meaning but also values exact terms.
            Examples:
            - "email about project deadline from last week"
            - "functional programming articles comparing Haskell and Scala"
            - "that blog post about AI safety and alignment"
            - "recipe for chocolate cake Sarah sent me"
            Pro tip: Include both concepts and specific keywords for best results.

        previews: Whether to include content snippets in results.
            - True: Returns preview text and image previews (useful for quick scanning)
            - False: Returns just metadata (faster, less data)
            Default is False.

        modalities: Types of content to search. Leave empty to search all.
            Available types:
            - 'email': Email messages
            - 'blog': Blog posts and articles
            - 'book': Book sections and ebooks
            - 'forum': Forum posts (e.g., LessWrong, Reddit)
            - 'observation': AI observations (use search_observations instead)
            - 'photo': Images with extracted text
            - 'comic': Comics and graphic content
            - 'webpage': General web pages
            Examples:
            - ["email"] - only emails
            - ["blog", "forum"] - articles and forum posts
            - [] - search everything

        limit: Maximum results to return (1-100). Default 10.
            Increase for comprehensive searches, decrease for quick lookups.

    Returns:
        List of search results ranked by relevance, each containing:
        - id: Unique identifier for the source item
        - score: Relevance score (0-1, higher is better)
        - chunks: Matching content segments with metadata
        - content: Full details including:
            - For emails: sender, recipient, subject, date
            - For blogs: author, title, url, publish date
            - For books: title, author, chapter info
            - Type-specific fields for each modality
        - filename: Path to file if content is stored on disk

    Examples:
        # Find specific email
        results = await search_knowledge_base(
            query="Sarah deadline project proposal next Friday",
            modalities=["email"],
            previews=True,
            limit=5
        )

        # Search for technical articles
        results = await search_knowledge_base(
            query="functional programming monads category theory",
            modalities=["blog", "book"],
            limit=20
        )

        # Find everything about a topic
        results = await search_knowledge_base(
            query="machine learning deployment kubernetes docker",
            previews=True
        )

        # Quick lookup of a remembered document
        results = await search_knowledge_base(
            query="tax forms 2023 accountant recommendations",
            modalities=["email"],
            limit=3
        )

    Best practices:
        - Include context in queries ("email from Sarah" vs just "Sarah")
        - Use modalities to filter when you know the content type
        - Enable previews when you need to verify content before using
        - Combine with search_observations for complete context
        - Higher scores (>0.7) indicate strong matches
        - If no results, try broader queries or different phrasing
    """
    logger.info(f"MCP search for: {query}")

    upload_data = extract.extract_text(query)
    results = await search(
        upload_data,
        previews=previews,
        modalities=modalities,
        limit=limit,
        min_text_score=0.3,
        min_multimodal_score=0.3,
    )

    # Convert SearchResult objects to dictionaries for MCP
    return [result.model_dump() for result in results]
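Since the tool returns plain dicts, a caller can post-filter on the documented score range. A small sketch (the result dicts here are hypothetical, shaped like `SearchResult.model_dump()`):

    results = [
        {"id": 1, "chunks": [{"score": 0.82}]},
        {"id": 2, "chunks": [{"score": 0.41}]},
    ]
    strong = [r for r in results if max(c["score"] for c in r["chunks"]) > 0.7]
    print([r["id"] for r in strong])  # [1] -- only matches above the 0.7 guideline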
@mcp.tool()
async def observe(
    content: str,
    subject: str,
    observation_type: str = "general",
    confidence: float = 0.8,
    evidence: dict | None = None,
    tags: list[str] | None = None,
    session_id: str | None = None,
    agent_model: str = "unknown",
) -> dict:
    """
    Record an observation about the user to build long-term understanding.

    Purpose:
        This tool is part of a memory system designed to help AI agents build a
        deep, persistent understanding of users over time. Use it to record any
        notable information about the user's preferences, beliefs, behaviors, or
        characteristics. These observations accumulate to create a comprehensive
        model of the user that improves future interactions.

    Quick Reference:
        # Most common patterns:
        observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
        observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
        observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
        observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")

    When to use:
        - User expresses a preference or opinion
        - You notice a behavioral pattern
        - User reveals information about their work/life/interests
        - You spot a contradiction with previous statements
        - Any insight that would help understand the user better in future

    Important: Be an active observer. Don't wait to be asked - proactively record
    observations throughout conversations to build understanding.

    Args:
        content: The observation itself. Be specific and detailed. Write complete
            thoughts that will make sense when read months later without context.
            Bad: "Likes FP"
            Good: "User strongly prefers functional programming paradigms, especially
            pure functions and immutability, considering them more maintainable"

        subject: A consistent identifier for what this observation is about. Use
            snake_case and be consistent across observations to enable tracking.
            Examples:
            - "programming_style" (not "coding" or "development")
            - "work_habits" (not "productivity" or "work_patterns")
            - "ai_safety_beliefs" (not "AI" or "artificial_intelligence")

        observation_type: Categorize the observation:
            - "belief": An opinion or belief the user holds
            - "preference": Something they prefer or favor
            - "behavior": A pattern in how they act or work
            - "contradiction": An inconsistency with previous observations
            - "general": Doesn't fit other categories

        confidence: How certain you are (0.0-1.0):
            - 1.0: User explicitly stated this
            - 0.9: Strongly implied or demonstrated repeatedly
            - 0.8: Inferred with high confidence (default)
            - 0.7: Probable but with some uncertainty
            - 0.6 or below: Speculative, use sparingly

        evidence: Supporting context as a dict. Include relevant details:
            - "quote": Exact words from the user
            - "context": What prompted this observation
            - "timestamp": When this was observed
            - "related_to": Connection to other topics
            Example: {
                "quote": "I always refactor to pure functions",
                "context": "Discussing code review practices"
            }

        tags: Categorization labels. Use lowercase with hyphens. Common patterns:
            - Topics: "machine-learning", "web-development", "philosophy"
            - Projects: "project:website-redesign", "project:thesis"
            - Contexts: "context:work", "context:personal", "context:late-night"
            - Domains: "domain:finance", "domain:healthcare"

        session_id: UUID string to group observations from the same conversation.
            Generate one UUID per conversation and reuse it for all observations
            in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"

        agent_model: Which AI model made this observation (e.g., "claude-3-opus",
            "gpt-4", "claude-3.5-sonnet"). Helps track observation quality.

    Returns:
        Dict with created observation details:
        - id: Unique identifier for reference
        - created_at: Timestamp of creation
        - subject: The subject as stored
        - observation_type: The type as stored
        - confidence: The confidence score
        - tags: List of applied tags

    Examples:
        # After user mentions their coding philosophy
        await observe(
            content="User believes strongly in functional programming principles, "
                    "particularly avoiding mutable state which they call 'the root "
                    "of all evil'. They prioritize code purity over performance.",
            subject="programming_philosophy",
            observation_type="belief",
            confidence=0.95,
            evidence={
                "quote": "State is the root of all evil in programming",
                "context": "Discussing why they chose Haskell for their project"
            },
            tags=["programming", "functional-programming", "philosophy"],
            session_id="550e8400-e29b-41d4-a716-446655440000",
            agent_model="claude-3-opus"
        )

        # Noticing a work pattern
        await observe(
            content="User frequently works on complex problems late at night, "
                    "typically between 11pm and 3am, claiming better focus",
            subject="work_schedule",
            observation_type="behavior",
            confidence=0.85,
            evidence={
                "context": "Mentioned across multiple conversations over 2 weeks"
            },
            tags=["behavior", "work-habits", "productivity", "context:late-night"],
            agent_model="claude-3-opus"
        )

        # Recording a contradiction
        await observe(
            content="User now advocates for microservices architecture, but "
                    "previously argued strongly for monoliths in similar contexts",
            subject="architecture_preferences",
            observation_type="contradiction",
            confidence=0.9,
            evidence={
                "quote": "Microservices are definitely the way to go",
                "context": "Designing a new system similar to one from 3 months ago"
            },
            tags=["architecture", "contradiction", "software-design"],
            agent_model="gpt-4"
        )
    """
    # Create the observation
    observation = AgentObservation(
        content=content,
        subject=subject,
        observation_type=observation_type,
        confidence=confidence,
        evidence=evidence,
        tags=tags or [],
        session_id=uuid.UUID(session_id) if session_id else None,
        agent_model=agent_model,
        size=len(content),
        mime_type="text/plain",
        sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
        inserted_at=datetime.now(timezone.utc),
    )
    try:
        with make_session() as session:
            process_content_item(observation, session)

            if not observation.id:
                raise ValueError("Observation not created")

            logger.info(
                f"Observation created: {observation.id}, {observation.inserted_at}"
            )
            return {
                "id": observation.id,
                "created_at": observation.inserted_at.isoformat(),
                "subject": observation.subject,
                "observation_type": observation.observation_type,
                "confidence": cast(float, observation.confidence),
                "tags": observation.tags,
            }

    except Exception as e:
        logger.error(f"Error creating observation: {e}")
        raise
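The `sha256` field above acts as the dedup key: two observations with identical content, subject, and type hash to the same digest, which trips the unique constraint on the `sha256` column (see the SourceItem model change below). A standalone sketch of the key derivation:

    from hashlib import sha256

    def observation_key(content: str, subject: str, observation_type: str) -> bytes:
        # Same recipe as the observe() tool: concatenate the three fields and hash.
        return sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest()

    a = observation_key("User prefers FP", "programming_style", "preference")
    b = observation_key("User prefers FP", "programming_style", "preference")
    print(a == b)  # True -- a repeated observation would collide on the unique column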
@mcp.tool()
async def search_observations(
    query: str,
    subject: str = "",
    tags: list[str] | None = None,
    observation_types: list[str] | None = None,
    min_confidence: float = 0.5,
    limit: int = 10,
) -> list[dict]:
    """
    Search through observations to understand the user better.

    Purpose:
        This tool searches through all observations recorded about the user using
        the 'observe' tool. Use it to recall past insights, check for patterns,
        find contradictions, or understand the user's preferences before responding.
        The more you use this tool, the more personalized and insightful your
        responses can be.

    When to use:
        - Before answering questions where user preferences might matter
        - When the user references something from the past
        - To check if current behavior aligns with past patterns
        - To find related observations on a topic
        - To build context about the user's expertise or interests
        - Whenever personalization would improve your response

    How it works:
        Uses hybrid search combining semantic similarity with keyword matching.
        Searches across multiple embedding spaces (semantic meaning and temporal
        context) to find relevant observations from different angles. This approach
        ensures you find both conceptually related and specifically mentioned items.

    Args:
        query: Natural language description of what you're looking for. The search
            matches both meaning and specific terms in observation content.
            Examples:
            - "programming preferences and coding style"
            - "opinions about artificial intelligence and AI safety"
            - "work habits productivity patterns when does user work best"
            - "previous projects the user has worked on"
            Pro tip: Use natural language but include key terms you expect to find.

        subject: Filter by exact subject identifier. Must match subjects used when
            creating observations (e.g., "programming_style", "work_habits").
            Leave empty to search all subjects. Use this when you know the exact
            subject category you want.

        tags: Filter results to only observations with these tags. Observations must
            have at least one matching tag. Use the same format as when creating:
            - ["programming", "functional-programming"]
            - ["context:work", "project:thesis"]
            - ["domain:finance", "machine-learning"]

        observation_types: Filter by type of observation:
            - "belief": Opinions or beliefs the user holds
            - "preference": Things they prefer or favor
            - "behavior": Patterns in how they act or work
            - "contradiction": Noted inconsistencies
            - "general": Other observations
            Leave as None to search all types.

        min_confidence: Only return observations with confidence >= this value.
            - Use 0.8+ for high-confidence facts
            - Use 0.5-0.7 to include inferred observations
            - Default 0.5 includes most observations
            Range: 0.0 to 1.0

        limit: Maximum results to return (1-100). Default 10. Increase when you
            need comprehensive understanding of a topic.

    Returns:
        List of observations sorted by relevance, each containing:
        - subject: What the observation is about
        - content: The full observation text
        - observation_type: Type of observation
        - evidence: Supporting context/quotes if provided
        - confidence: How certain the observation is (0-1)
        - agent_model: Which AI model made the observation
        - tags: All tags on this observation
        - created_at: When it was observed (if available)

    Examples:
        # Before discussing code architecture
        results = await search_observations(
            query="software architecture preferences microservices monoliths",
            tags=["architecture"],
            min_confidence=0.7
        )

        # Understanding work style for scheduling
        results = await search_observations(
            query="when does user work best productivity schedule",
            observation_types=["behavior", "preference"],
            subject="work_schedule"
        )

        # Check for AI safety views before discussing AI
        results = await search_observations(
            query="artificial intelligence safety alignment concerns",
            observation_types=["belief"],
            min_confidence=0.8,
            limit=20
        )

        # Find contradictions on a topic
        results = await search_observations(
            query="testing methodology unit tests integration",
            observation_types=["contradiction"],
            tags=["testing", "software-development"]
        )

    Best practices:
        - Search before making assumptions about user preferences
        - Use broad queries first, then filter with tags/types if too many results
        - Check for contradictions when user says something unexpected
        - Higher confidence observations are more reliable
        - Recent observations may override older ones on the same topic
    """
    source_ids = None
    if tags or observation_types:
        with make_session() as session:
            items_query = session.query(AgentObservation.id)

            if tags:
                # Use PostgreSQL array overlap operator with proper array casting
                items_query = items_query.filter(
                    AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
                )
            if observation_types:
                items_query = items_query.filter(
                    AgentObservation.observation_type.in_(observation_types)
                )
            source_ids = [item.id for item in items_query.all()]
            if not source_ids:
                return []

    semantic_text = observation.generate_semantic_text(
        subject=subject or "",
        observation_type="".join(observation_types or []),
        content=query,
        evidence=None,
    )
    temporal = observation.generate_temporal_text(
        subject=subject or "",
        content=query,
        confidence=0,
        created_at=datetime.now(timezone.utc),
    )
    results = await search(
        [
            extract.DataChunk(data=[query]),
            extract.DataChunk(data=[semantic_text]),
            extract.DataChunk(data=[temporal]),
        ],
        previews=True,
        modalities=["semantic", "temporal"],
        limit=limit,
        min_text_score=0.8,
        filters=SearchFilters(
            subject=subject,
            confidence=min_confidence,
            tags=tags,
            observation_types=observation_types,
            source_ids=source_ids,
        ),
    )

    return [
        cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
    ]
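The `.op("&&")` filter above is PostgreSQL's array-overlap operator (`tags && ARRAY[...]`): a row matches if it shares at least one tag with the requested list. The same semantics in plain Python:

    row_tags = {"architecture", "software-design"}
    wanted = {"architecture", "testing"}
    print(bool(row_tags & wanted))  # True -- one shared tag is enough to match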
@@ -18,6 +18,7 @@ from memory.common.db.models import (
     ArticleFeed,
     EmailAccount,
     ForumPost,
+    AgentObservation,
 )

@@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
 class ChunkAdmin(ModelView, model=Chunk):
     column_list = ["id", "source_id", "embedding_model", "created_at"]
+    column_sortable_list = ["created_at"]


 class MailMessageAdmin(ModelView, model=MailMessage):

@@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
     column_searchable_list = ["name", "email_address"]


+class AgentObservationAdmin(ModelView, model=AgentObservation):
+    column_list = [
+        "id",
+        "content",
+        "subject",
+        "observation_type",
+        "confidence",
+        "evidence",
+        "inserted_at",
+    ]
+    column_searchable_list = ["subject", "observation_type"]
+
+
 def setup_admin(admin: Admin):
     """Add all admin views to the admin instance."""
     admin.add_view(SourceItemAdmin)
+    admin.add_view(AgentObservationAdmin)
     admin.add_view(ChunkAdmin)
     admin.add_view(EmailAccountAdmin)
     admin.add_view(MailMessageAdmin)
@@ -2,6 +2,7 @@
 """
 FastAPI application for the knowledge base.
 """

+import contextlib
 import pathlib
 import logging
 from typing import Annotated, Optional

@@ -15,15 +16,25 @@ from memory.common import extract
 from memory.common.db.connection import get_engine
 from memory.api.admin import setup_admin
 from memory.api.search import search, SearchResult
+from memory.api.MCP.tools import mcp

 logger = logging.getLogger(__name__)

-app = FastAPI(title="Knowledge Base API")
+
+@contextlib.asynccontextmanager
+async def lifespan(app: FastAPI):
+    async with contextlib.AsyncExitStack() as stack:
+        await stack.enter_async_context(mcp.session_manager.run())
+        yield
+
+
+app = FastAPI(title="Knowledge Base API", lifespan=lifespan)

 # SQLAdmin setup
 engine = get_engine()
 admin = Admin(app, engine)
 setup_admin(admin)
+app.mount("/", mcp.streamable_http_app())


 @app.get("/health")

@@ -104,4 +115,7 @@ def main():


 if __name__ == "__main__":
+    from memory.common.qdrant import setup_qdrant
+
+    setup_qdrant()
     main()
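A quick way to exercise this wiring once the app is running (assumptions: a local server on port 8000, and that the /health route registered above returns a small body):

    import asyncio
    import httpx

    async def check() -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            resp = await client.get("/health")  # FastAPI route registered above
            print(resp.status_code, resp.text)

    asyncio.run(check())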
@@ -2,21 +2,29 @@
 """
 Search endpoints for the knowledge base API.
 """

+import asyncio
 import base64
+from hashlib import sha256
 import io
 import logging
 from collections import defaultdict
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, TypedDict, NotRequired

+import bm25s
+import Stemmer
-import qdrant_client
 from PIL import Image
 from pydantic import BaseModel
+import qdrant_client
 from qdrant_client.http import models as qdrant_models

-from memory.common import embedding, qdrant, extract, settings
-from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
+from memory.common import embedding, extract, qdrant, settings
+from memory.common.collections import (
+    ALL_COLLECTIONS,
+    MULTIMODAL_COLLECTIONS,
+    TEXT_COLLECTIONS,
+)
 from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
+from memory.common.db.models import Chunk

 logger = logging.getLogger(__name__)
@@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
     preview: Optional[str | None] = None


+class SourceData(BaseModel):
+    """Holds source item data to avoid SQLAlchemy session issues"""
+
+    id: int
+    size: int | None
+    mime_type: str | None
+    filename: str | None
+    content: str | dict | None
+    content_length: int
+
+
 class SearchResponse(BaseModel):
     collection: str
     results: list[dict]

@@ -38,16 +57,46 @@ class SearchResult(BaseModel):
     size: int
     mime_type: str
     chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
+    content: Optional[str | dict] = None
     filename: Optional[str] = None


+class SearchFilters(TypedDict):
+    subject: NotRequired[str | None]
+    confidence: NotRequired[float]
+    tags: NotRequired[list[str] | None]
+    observation_types: NotRequired[list[str] | None]
+    source_ids: NotRequired[list[int] | None]
+
+
+async def with_timeout(
+    call, timeout: int = 2
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    """
+    Run a function with a timeout.
+
+    Args:
+        call: The function to run
+        timeout: The timeout in seconds
+    """
+    try:
+        return await asyncio.wait_for(call, timeout=timeout)
+    except TimeoutError:
+        logger.warning(f"Search timed out after {timeout}s")
+        return []
+    except Exception as e:
+        logger.error(f"Search failed: {e}")
+        return []
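A usage sketch for `with_timeout` (the slow coroutine is a hypothetical stand-in for one of the search backends; on Python 3.11+, `asyncio.TimeoutError` is the builtin `TimeoutError`, which is what the except clause above relies on):

    import asyncio

    async def slow_backend():  # hypothetical stand-in for search_bm25 / search_embeddings
        await asyncio.sleep(5)
        return [("source", "chunk")]

    async def main() -> None:
        hits = await with_timeout(slow_backend(), timeout=2)
        print(hits)  # [] -- the 2s timeout fired and was swallowed

    asyncio.run(main())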
 def annotated_chunk(
     chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
+) -> tuple[SourceData, AnnotatedChunk]:
     def serialize_item(item: bytes | str | Image.Image) -> str | None:
         if not previews and not isinstance(item, str):
             return None
         if not previews and isinstance(item, str):
             return item[:100]

         if isinstance(item, Image.Image):
             buffer = io.BytesIO()

@@ -68,7 +117,19 @@ def annotated_chunk(
         for k, v in metadata.items()
         if k not in ["content", "filename", "size", "content_type", "tags"]
     }
-    return chunk.source, AnnotatedChunk(
+
+    # Prefetch all needed source data while in session
+    source = chunk.source
+    source_data = SourceData(
+        id=source.id,
+        size=source.size,
+        mime_type=source.mime_type,
+        filename=source.filename,
+        content=source.display_contents,
+        content_length=len(source.content) if source.content else 0,
+    )
+
+    return source_data, AnnotatedChunk(
         id=str(chunk.id),
         score=search_result.score,
         metadata=metadata,

@@ -76,24 +137,28 @@ def annotated_chunk(
     )


-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
     items = defaultdict(list)
+    source_lookup = {}

     for source, chunk in chunks:
-        items[source].append(chunk)
+        items[source.id].append(chunk)
+        source_lookup[source.id] = source

     return [
         SearchResult(
             id=source.id,
-            size=source.size or len(source.content),
+            size=source.size or source.content_length,
             mime_type=source.mime_type or "text/plain",
             filename=source.filename
             and source.filename.replace(
                 str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
             ),
-            content=source.display_contents,
+            content=source.content,
             chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
         )
-        for source, chunks in items.items()
+        for source_id, chunks in items.items()
+        for source in [source_lookup[source_id]]
     ]
@@ -104,25 +169,31 @@ def query_chunks(
     embedder: Callable,
     min_score: float = 0.0,
     limit: int = 10,
+    filters: dict[str, Any] | None = None,
 ) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
+    if not upload_data or not allowed_modalities:
         return {}

-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
+    chunks = [chunk for chunk in upload_data if chunk.data]
     if not chunks:
+        logger.error(f"No chunks to embed for {allowed_modalities}")
         return {}

-    vector = embedder(chunks, input_type="query")[0]
+    logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
+    for c in chunks:
+        logger.error(f"Chunk: {c.data}")
+    vectors = embedder([c.data for c in chunks], input_type="query")

     return {
         collection: [
             r
+            for vector in vectors
             for r in qdrant.search_vectors(
                 client=client,
                 collection_name=collection,
                 query_vector=vector,
                 limit=limit,
+                filter_params=filters,
             )
             if r.score >= min_score
         ]
@@ -130,6 +201,122 @@ def query_chunks(
     }
async def search_bm25(
    query: str,
    modalities: list[str],
    limit: int = 10,
    filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
    with make_session() as db:
        items_query = db.query(Chunk.id, Chunk.content).filter(
            Chunk.collection_name.in_(modalities)
        )
        if source_ids := filters.get("source_ids"):
            items_query = items_query.filter(Chunk.source_id.in_(source_ids))
        items = items_query.all()
        item_ids = {
            sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
            for item in items
        }
        corpus = [item.content.lower().strip() for item in items]

    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    results, scores = retriever.retrieve(
        query_tokens, k=min(limit, len(corpus)), corpus=corpus
    )

    item_scores = {
        item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
        for doc, score in zip(results[0], scores[0])
    }

    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
        results = []
        for chunk in chunks:
            # Prefetch all needed source data while in session
            source = chunk.source
            source_data = SourceData(
                id=source.id,
                size=source.size,
                mime_type=source.mime_type,
                filename=source.filename,
                content=source.display_contents,
                content_length=len(source.content) if source.content else 0,
            )

            annotated = AnnotatedChunk(
                id=str(chunk.id),
                score=item_scores[chunk.id],
                metadata=source.as_payload(),
                preview=None,
            )
            results.append((source_data, annotated))

    return results
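The bm25s calls used above, shown standalone on a toy corpus (requires the bm25s and PyStemmer packages, as pulled in by the requirements change):

    import bm25s
    import Stemmer

    stemmer = Stemmer.Stemmer("english")
    corpus = ["functional programming in haskell", "deploying models on kubernetes"]

    retriever = bm25s.BM25()
    retriever.index(bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer))

    docs, scores = retriever.retrieve(
        bm25s.tokenize("haskell programming", stemmer=stemmer), k=1, corpus=corpus
    )
    print(docs[0][0], scores[0][0])  # best-matching document and its BM25 score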
async def search_embeddings(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    min_score: float = 0.3,
    filters: SearchFilters = SearchFilters(),
    multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
    Search across the knowledge base using a text query and optional files.

    Parameters:
    - data: List of data to search in (e.g., text, images, files)
    - previews: Whether to include previews in the search results
    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
    - limit: Maximum number of results
    - min_score: Minimum score to include in the search results
    - filters: Filters to apply to the search results
    - multimodal: Whether to search in multimodal collections
    """
    query_filters = {
        "must": [
            {"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
        ],
    }
    if tags := filters.get("tags"):
        query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
    if observation_types := filters.get("observation_types"):
        query_filters["must"] += [
            {"key": "observation_type", "match": {"any": observation_types}}
        ]

    client = qdrant.get_qdrant_client()
    results = query_chunks(
        client,
        data,
        modalities,
        embedding.embed_text if not multimodal else embedding.embed_mixed,
        min_score=min_score,
        limit=limit,
        filters=query_filters,
    )
    search_results = {k: results.get(k, []) for k in modalities}

    found_chunks = {
        str(r.id): r for results in search_results.values() for r in results
    }
    with make_session() as db:
        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
        return [
            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
            for chunk in chunks
        ]
 async def search(
     data: list[extract.DataChunk],
     previews: Optional[bool] = False,
@@ -137,6 +324,7 @@ async def search(
     limit: int = 10,
     min_text_score: float = 0.3,
     min_multimodal_score: float = 0.3,
+    filters: SearchFilters = {},
 ) -> list[SearchResult]:
     """
     Search across the knowledge base using a text query and optional files.
@@ -150,40 +338,45 @@ async def search(
     Returns:
     - List of search results sorted by score
     """
-    client = qdrant.get_qdrant_client()
     allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
-        data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
-        limit=limit,
-    )
-    multimodal_results = query_chunks(
-        client,
-        data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
+    text_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & TEXT_COLLECTIONS,
+            limit,
+            min_text_score,
+            filters,
+            multimodal=False,
+        )
+    )
+    multimodal_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & MULTIMODAL_COLLECTIONS,
+            limit,
+            min_multimodal_score,
+            filters,
+            multimodal=True,
+        )
+    )
+    bm25_results = with_timeout(
+        search_bm25(
+            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
+            modalities,
+            limit=limit,
+            filters=filters,
+        )
+    )
+
+    results = await asyncio.gather(
+        text_embeddings_results,
+        multimodal_embeddings_results,
+        bm25_results,
+        return_exceptions=False,
+    )
+
+    results = group_chunks([c for r in results for c in r])
     return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
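The final ordering keys each grouped result by its best chunk. A minimal sketch of that ranking rule on dummy data:

    # Dummy results: (id, chunk scores) -- ranked by the maximum chunk score.
    results = [("a", [0.4, 0.9]), ("b", [0.7]), ("c", [0.85, 0.2])]
    ranked = sorted(results, key=lambda r: max(r[1]), reverse=True)
    print([r[0] for r in ranked])  # ['a', 'c', 'b']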
@@ -1,6 +1,7 @@
 import logging
 from typing import Literal, NotRequired, TypedDict

+from PIL import Image
+
 from memory.common import settings

@@ -14,84 +15,92 @@ Vector = list[float]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
     model: str
     on_disk: NotRequired[bool]
     shards: NotRequired[int]
     text: bool
     multimodal: bool


 ALL_COLLECTIONS: dict[str, Collection] = {
     "mail": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
     "chat": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": True,
     },
     "git": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
     "book": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
     "blog": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
         "text": True,
         "multimodal": True,
     },
     "forum": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
         "text": True,
         "multimodal": True,
     },
     "text": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
     # Multimodal
     "photo": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
         "text": False,
         "multimodal": True,
     },
     "comic": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
         "text": False,
         "multimodal": True,
     },
     "doc": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.MIXED_EMBEDDING_MODEL,
         "text": False,
         "multimodal": True,
     },
     # Observations
     "semantic": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
     "temporal": {
         "dimension": 1024,
         "distance": "Cosine",
         "model": settings.TEXT_EMBEDDING_MODEL,
         "text": True,
         "multimodal": False,
     },
 }
 TEXT_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
 }
 MULTIMODAL_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
 }

 TYPES = {

@@ -119,5 +128,12 @@ def get_modality(mime_type: str) -> str:
     return "unknown"


-def collection_model(collection: str) -> str | None:
-    return ALL_COLLECTIONS.get(collection, {}).get("model")
+def collection_model(
+    collection: str, text: str, images: list[Image.Image]
+) -> str | None:
+    config = ALL_COLLECTIONS.get(collection, {})
+    if images and config.get("multimodal"):
+        return settings.MIXED_EMBEDDING_MODEL
+    if text and config.get("text"):
+        return settings.TEXT_EMBEDDING_MODEL
+    return "unknown"
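With the new `text`/`multimodal` flags, a collection can belong to both search sets, which the old model-based test ruled out. A sketch of the membership logic on a trimmed collection map:

    ALL = {
        "mail": {"text": True, "multimodal": False},
        "chat": {"text": True, "multimodal": True},
        "photo": {"text": False, "multimodal": True},
    }
    TEXT = {c for c, p in ALL.items() if p.get("text")}
    MULTIMODAL = {c for c, p in ALL.items() if p.get("multimodal")}
    print(TEXT & MULTIMODAL)  # {'chat'} -- searched by both text and multimodal paths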
@@ -205,7 +205,9 @@ class SourceItem(Base):
     id = Column(BigInteger, primary_key=True)
     modality = Column(Text, nullable=False)
     sha256 = Column(BYTEA, nullable=False, unique=True)
-    inserted_at = Column(DateTime(timezone=True), server_default=func.now())
+    inserted_at = Column(
+        DateTime(timezone=True), server_default=func.now(), default=func.now()
+    )
     tags = Column(ARRAY(Text), nullable=False, server_default="{}")
     size = Column(Integer)
     mime_type = Column(Text)

@@ -261,14 +263,15 @@ class SourceItem(Base):
         images = [c for c in data.data if isinstance(c, Image.Image)]
         image_names = image_filenames(chunk_id, images)

+        modality = data.modality or cast(str, self.modality)
         chunk = Chunk(
             id=chunk_id,
             source=self,
             content=text or None,
             images=images,
             file_paths=image_names,
-            collection_name=data.collection_name or cast(str, self.modality),
-            embedding_model=collections.collection_model(cast(str, self.modality)),
+            collection_name=modality,
+            embedding_model=collections.collection_model(modality, text, images),
             item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
         )
         return chunk

@@ -284,5 +287,8 @@ class SourceItem(Base):
     }

     @property
-    def display_contents(self) -> str | None:
-        return cast(str | None, self.content) or cast(str | None, self.filename)
+    def display_contents(self) -> str | dict | None:
+        return {
+            "tags": self.tags,
+            "size": self.size,
+        }
@@ -570,13 +570,23 @@ class AgentObservation(SourceItem):
         """Get all contradictions involving this observation."""
         return self.contradictions_as_first + self.contradictions_as_second

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[extract.DataChunk]:
+    @property
+    def display_contents(self) -> dict:
+        return {
+            "subject": self.subject,
+            "content": self.content,
+            "observation_type": self.observation_type,
+            "evidence": self.evidence,
+            "confidence": self.confidence,
+            "agent_model": self.agent_model,
+            "tags": self.tags,
+        }
+
+    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
         """
         Generate multiple chunks for different embedding dimensions.
         Each chunk goes to a different Qdrant collection for specialized search.
         """
-        chunks = []
-
         # 1. Semantic chunk - standard content representation
         semantic_text = observation.generate_semantic_text(
             cast(str, self.subject),
@@ -584,12 +594,10 @@ class AgentObservation(SourceItem):
             cast(str, self.content),
             cast(observation.Evidence, self.evidence),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[semantic_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
-                collection_name="semantic",
-            )
-        )
+        semantic_chunk = extract.DataChunk(
+            data=[semantic_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+            modality="semantic",
+        )

         # 2. Temporal chunk - time-aware representation
@@ -599,14 +607,29 @@ class AgentObservation(SourceItem):
             cast(float, self.confidence),
             cast(datetime, self.inserted_at),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[temporal_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
-                collection_name="temporal",
-            )
-        )
+        temporal_chunk = extract.DataChunk(
+            data=[temporal_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
+            modality="temporal",
+        )
+
+        others = [
+            self._make_chunk(
+                extract.DataChunk(
+                    data=[i],
+                    metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+                    modality="semantic",
+                )
+            )
+            for i in [
+                self.content,
+                self.subject,
+                self.observation_type,
+                self.evidence.get("quote", ""),
+            ]
+            if i
+        ]

         # TODO: Add more embedding dimensions here:
         # 3. Epistemic chunk - belief structure focused
         # epistemic_text = self._generate_epistemic_text()
@@ -632,4 +655,7 @@ class AgentObservation(SourceItem):
         #     collection_name="observations_relational"
         # ))

-        return chunks
+        return [
+            self._make_chunk(semantic_chunk),
+            self._make_chunk(temporal_chunk),
+        ] + others
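Net effect of the `data_chunks` change: one observation now fans out to the semantic and temporal collections plus one extra semantic chunk per non-empty fragment. A rough count, assuming all four fragments are present:

    fragments = ["content", "subject", "observation_type", "evidence quote"]
    chunks = ["semantic", "temporal"] + ["semantic" for f in fragments if f]
    print(len(chunks))  # 6 chunks for a fully populated observation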
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)

 def embed_chunks(
-    chunks: list[str] | list[list[extract.MulitmodalChunk]],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:

@@ -24,52 +24,50 @@ def embed_chunks(
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks,  # type: ignore
+            chunks,
             model=model,
             input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore
+
+    texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
+    return cast(
+        list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
+    )
+
+
+def break_chunk(
+    chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
+) -> list[extract.MulitmodalChunk]:
+    result = []
+    for c in chunk:
+        if isinstance(c, str):
+            result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
+        else:
+            result.append(chunk)
+    return result


 def embed_text(
-    texts: list[str],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunks = [
-        c
-        for text in texts
-        if isinstance(text, str)
-        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
-        if c.strip()
-    ]
-    if not chunks:
+    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
+    if not any(chunked_chunks):
         return []

-    try:
-        return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:  # type: ignore
-        logger.error(f"Error embedding text: {e}")
-        logger.debug(f"Text: {texts}")
-        raise
+    return embed_chunks(chunked_chunks, model, input_type)


 def embed_mixed(
-    items: list[extract.MulitmodalChunk],
+    items: list[list[extract.MulitmodalChunk]],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
-        if isinstance(item, str):
-            return [
-                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
-            ]
-        return [item]
-
-    chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model, input_type)
+    chunked_chunks = [break_chunk(item, chunk_size) for item in items]
+    return embed_chunks(chunked_chunks, model, input_type)


 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:

@@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:

 def embed_source_item(item: SourceItem) -> list[Chunk]:
     chunks = list(item.data_chunks())
+    logger.error(
+        f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
+    )
     if not chunks:
         return []
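A standalone sketch of the splitting idea behind `break_chunk`, with a simple character-based splitter standing in for the token-based `chunk_text` (sizes here are illustrative, not the real token limits). Strings get split with overlap while non-string items pass through untouched:

    def split_text(text: str, size: int, overlap: int) -> list[str]:
        # Overlapping windows, a crude stand-in for token-aware chunking.
        step = max(size - overlap, 1)
        return [text[i : i + size] for i in range(0, len(text), step)]

    mixed = ["a fairly long string that needs splitting", object()]  # str + non-str
    flat: list = []
    for part in mixed:
        flat += split_text(part, 16, 4) if isinstance(part, str) else [part]
    print(flat)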
@@ -21,7 +21,7 @@ class DataChunk:
     data: Sequence[MulitmodalChunk]
     metadata: dict[str, Any] = field(default_factory=dict)
     mime_type: str = "text/plain"
-    collection_name: str | None = None
+    modality: str | None = None


 @contextmanager
@@ -8,7 +8,7 @@ class Evidence(TypedDict):


 def generate_semantic_text(
-    subject: str, observation_type: str, content: str, evidence: Evidence
+    subject: str, observation_type: str, content: str, evidence: Evidence | None = None
 ) -> str:
     """Generate text optimized for semantic similarity search."""
     parts = [

@@ -54,8 +54,9 @@ def generate_temporal_text(
         f"Time: {time_of_day} on {day_of_week} ({time_period})",
         f"Subject: {subject}",
         f"Observation: {content}",
-        f"Confidence: {confidence}",
     ]
+    if confidence:
+        parts.append(f"Confidence: {confidence}")

     return " | ".join(parts)
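After this change, a falsy confidence (such as the 0 passed by `search_observations` when building its query text) drops the confidence field instead of printing "Confidence: 0". A sketch of the resulting layout (field values are examples):

    parts = [
        "Time: 23:15 on Tuesday (night)",
        "Subject: work_schedule",
        "Observation: User works best late at night",
    ]
    confidence = 0.85
    if confidence:  # omitted entirely when confidence is 0
        parts.append(f"Confidence: {confidence}")
    print(" | ".join(parts))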
@@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
 TAGS_PROMPT = """
 The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.

+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "{summary}",

@@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
 Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
 Also provide 3-5 relevant tags that capture the main topics or themes.

+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "your summary here",
@@ -1,62 +0,0 @@ (entire file deleted)
-import argparse
-import logging
-from typing import Any
-from fastapi import UploadFile
-import httpx
-from mcp.server.fastmcp import FastMCP
-
-SERVER = "http://localhost:8000"
-
-
-logger = logging.getLogger(__name__)
-mcp = FastMCP("memory")
-
-
-async def make_request(
-    path: str,
-    method: str,
-    data: dict | None = None,
-    json: dict | None = None,
-    files: list[UploadFile] | None = None,
-) -> httpx.Response:
-    async with httpx.AsyncClient() as client:
-        return await client.request(
-            method,
-            f"{SERVER}/{path}",
-            data=data,
-            json=json,
-            files=files,  # type: ignore
-        )
-
-
-async def post_data(path: str, data: dict | None = None) -> httpx.Response:
-    return await make_request(path, "POST", data=data)
-
-
-@mcp.tool()
-async def search(
-    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
-) -> list[dict[str, Any]]:
-    logger.error(f"Searching for {query}")
-    resp = await post_data(
-        "search",
-        {
-            "query": query,
-            "previews": previews,
-            "modalities": modalities,
-            "limit": limit,
-        },
-    )
-
-    return resp.json()
-
-
-if __name__ == "__main__":
-    # Initialize and run the server
-    args = argparse.ArgumentParser()
-    args.add_argument("--server", type=str)
-    args = args.parse_args()
-
-    SERVER = args.server
-
-    mcp.run(transport=args.transport)