Mirror of https://github.com/mruwnik/memory.git, synced 2025-06-08 21:34:42 +02:00

Commit b10a1fb130 ("initial memory tools"), parent 1dd93929c1
@@ -3,3 +3,4 @@ uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin
+mcp==1.9.2
@@ -6,3 +6,5 @@ dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.18.1
+
+bm25s[full]==0.2.13
src/memory/api/MCP/tools.py (new file, 654 lines)
@@ -0,0 +1,654 @@
"""
MCP tools for the epistemic sparring partner system.
"""

from datetime import datetime, timezone
from hashlib import sha256
import logging
import uuid
from typing import cast
from mcp.server.fastmcp import FastMCP
from memory.common.db.models.source_item import SourceItem
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy import func, cast as sql_cast, Text
from memory.common.db.connection import make_session

from memory.common import extract
from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters
from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item

logger = logging.getLogger(__name__)

# Create MCP server instance
mcp = FastMCP("memory", stateless=True)


@mcp.tool()
async def get_all_tags() -> list[str]:
    """
    Get all unique tags used across the entire knowledge base.

    Purpose:
        This tool retrieves all tags that have been used in the system, both from
        AI observations (created with 'observe') and other content. Use it to
        understand the tag taxonomy, ensure consistency, or discover related topics.

    When to use:
        - Before creating new observations, to use consistent tag naming
        - To explore what topics/contexts have been tracked
        - To build tag filters for search operations
        - To understand the user's areas of interest
        - For tag autocomplete or suggestion features

    Returns:
        Sorted list of all unique tags in the system. Tags follow patterns like:
        - Topics: "machine-learning", "functional-programming"
        - Projects: "project:website-redesign"
        - Contexts: "context:work", "context:late-night"
        - Domains: "domain:finance"

    Example:
        # Get all tags to ensure consistency
        tags = await get_all_tags()
        # Returns: ["ai-safety", "context:work", "functional-programming",
        #           "machine-learning", "project:thesis", ...]

        # Use to check if a topic has been discussed before
        if "quantum-computing" in tags:
            # Search for related observations
            observations = await search_observations(
                query="quantum computing",
                tags=["quantum-computing"]
            )
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})


@mcp.tool()
async def get_all_subjects() -> list[str]:
    """
    Get all unique subjects from observations about the user.

    Purpose:
        This tool retrieves all subject identifiers that have been used in
        observations (created with 'observe'). Subjects are the consistent
        identifiers for what observations are about. Use this to understand
        what aspects of the user have been tracked and ensure consistency.

    When to use:
        - Before creating new observations, to use existing subject names
        - To discover what aspects of the user have been observed
        - To build subject filters for targeted searches
        - To ensure consistent naming across observations
        - To get an overview of the user model

    Returns:
        Sorted list of all unique subjects. Common patterns include:
        - "programming_style", "programming_philosophy"
        - "work_habits", "work_schedule"
        - "ai_beliefs", "ai_safety_beliefs"
        - "learning_preferences"
        - "communication_style"

    Example:
        # Get all subjects to ensure consistency
        subjects = await get_all_subjects()
        # Returns: ["ai_safety_beliefs", "architecture_preferences",
        #           "programming_philosophy", "work_schedule", ...]

        # Use to check what we know about the user
        if "programming_style" in subjects:
            # Get all programming-related observations
            observations = await search_observations(
                query="programming",
                subject="programming_style"
            )

    Best practices:
        - Always check existing subjects before creating new ones
        - Use snake_case for consistency
        - Be specific but not too granular
        - Group related observations under same subject
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )


@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """
    Get all unique observation types that have been used.

    Purpose:
        This tool retrieves the distinct observation types that have been recorded
        in the system. While the standard types are predefined (belief, preference,
        behavior, contradiction, general), this shows what's actually been used.
        Helpful for understanding the distribution of observation types.

    When to use:
        - To see what types of observations have been made
        - To understand the balance of different observation types
        - To check if all standard types are being utilized
        - For analytics or reporting on observation patterns

    Standard types:
        - "belief": Opinions or beliefs the user holds
        - "preference": Things they prefer or favor
        - "behavior": Patterns in how they act or work
        - "contradiction": Noted inconsistencies
        - "general": Observations that don't fit other categories

    Returns:
        List of observation types that have actually been used in the system.

    Example:
        # Check what types of observations exist
        types = await get_all_observation_types()
        # Returns: ["behavior", "belief", "contradiction", "preference"]

        # Use to analyze observation distribution
        for obs_type in types:
            observations = await search_observations(
                query="",
                observation_types=[obs_type],
                limit=100
            )
            print(f"{obs_type}: {len(observations)} observations")
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )


@mcp.tool()
async def search_knowledge_base(
    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict]:
    """
    Search through the user's stored knowledge and content.

    Purpose:
        This tool searches the user's personal knowledge base - a collection of
        their saved content including emails, documents, blog posts, books, and
        more. Use this alongside 'search_observations' to build a complete picture:
        - search_knowledge_base: Finds user's actual content and information
        - search_observations: Finds AI-generated insights about the user
        Together they enable deeply personalized, context-aware assistance.

    When to use:
        - User asks about something they've read/written/received
        - You need to find specific content the user has saved
        - User references a document, email, or article
        - To provide quotes or information from user's sources
        - To understand context from user's past communications
        - When user says "that article about..." or similar references

    How it works:
        Uses hybrid search combining semantic understanding with keyword matching.
        This means it finds content based on meaning AND specific terms, giving
        you the best of both approaches. Results are ranked by relevance.

    Args:
        query: Natural language search query. Be descriptive about what you're
            looking for. The search understands meaning but also values exact terms.
            Examples:
            - "email about project deadline from last week"
            - "functional programming articles comparing Haskell and Scala"
            - "that blog post about AI safety and alignment"
            - "recipe for chocolate cake Sarah sent me"
            Pro tip: Include both concepts and specific keywords for best results.

        previews: Whether to include content snippets in results.
            - True: Returns preview text and image previews (useful for quick scanning)
            - False: Returns just metadata (faster, less data)
            Default is False.

        modalities: Types of content to search. Leave empty to search all.
            Available types:
            - 'email': Email messages
            - 'blog': Blog posts and articles
            - 'book': Book sections and ebooks
            - 'forum': Forum posts (e.g., LessWrong, Reddit)
            - 'observation': AI observations (use search_observations instead)
            - 'photo': Images with extracted text
            - 'comic': Comics and graphic content
            - 'webpage': General web pages
            Examples:
            - ["email"] - only emails
            - ["blog", "forum"] - articles and forum posts
            - [] - search everything

        limit: Maximum results to return (1-100). Default 10.
            Increase for comprehensive searches, decrease for quick lookups.

    Returns:
        List of search results ranked by relevance, each containing:
        - id: Unique identifier for the source item
        - score: Relevance score (0-1, higher is better)
        - chunks: Matching content segments with metadata
        - content: Full details including:
            - For emails: sender, recipient, subject, date
            - For blogs: author, title, url, publish date
            - For books: title, author, chapter info
            - Type-specific fields for each modality
        - filename: Path to file if content is stored on disk

    Examples:
        # Find specific email
        results = await search_knowledge_base(
            query="Sarah deadline project proposal next Friday",
            modalities=["email"],
            previews=True,
            limit=5
        )

        # Search for technical articles
        results = await search_knowledge_base(
            query="functional programming monads category theory",
            modalities=["blog", "book"],
            limit=20
        )

        # Find everything about a topic
        results = await search_knowledge_base(
            query="machine learning deployment kubernetes docker",
            previews=True
        )

        # Quick lookup of a remembered document
        results = await search_knowledge_base(
            query="tax forms 2023 accountant recommendations",
            modalities=["email"],
            limit=3
        )

    Best practices:
        - Include context in queries ("email from Sarah" vs just "Sarah")
        - Use modalities to filter when you know the content type
        - Enable previews when you need to verify content before using
        - Combine with search_observations for complete context
        - Higher scores (>0.7) indicate strong matches
        - If no results, try broader queries or different phrasing
    """
    logger.info(f"MCP search for: {query}")

    upload_data = extract.extract_text(query)
    results = await search(
        upload_data,
        previews=previews,
        modalities=modalities,
        limit=limit,
        min_text_score=0.3,
        min_multimodal_score=0.3,
    )

    # Convert SearchResult objects to dictionaries for MCP
    return [result.model_dump() for result in results]


@mcp.tool()
async def observe(
    content: str,
    subject: str,
    observation_type: str = "general",
    confidence: float = 0.8,
    evidence: dict | None = None,
    tags: list[str] | None = None,
    session_id: str | None = None,
    agent_model: str = "unknown",
) -> dict:
    """
    Record an observation about the user to build long-term understanding.

    Purpose:
        This tool is part of a memory system designed to help AI agents build a
        deep, persistent understanding of users over time. Use it to record any
        notable information about the user's preferences, beliefs, behaviors, or
        characteristics. These observations accumulate to create a comprehensive
        model of the user that improves future interactions.

    Quick Reference:
        # Most common patterns:
        observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
        observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
        observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
        observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")

    When to use:
        - User expresses a preference or opinion
        - You notice a behavioral pattern
        - User reveals information about their work/life/interests
        - You spot a contradiction with previous statements
        - Any insight that would help understand the user better in future

    Important: Be an active observer. Don't wait to be asked - proactively record
    observations throughout conversations to build understanding.

    Args:
        content: The observation itself. Be specific and detailed. Write complete
            thoughts that will make sense when read months later without context.
            Bad: "Likes FP"
            Good: "User strongly prefers functional programming paradigms, especially
            pure functions and immutability, considering them more maintainable"

        subject: A consistent identifier for what this observation is about. Use
            snake_case and be consistent across observations to enable tracking.
            Examples:
            - "programming_style" (not "coding" or "development")
            - "work_habits" (not "productivity" or "work_patterns")
            - "ai_safety_beliefs" (not "AI" or "artificial_intelligence")

        observation_type: Categorize the observation:
            - "belief": An opinion or belief the user holds
            - "preference": Something they prefer or favor
            - "behavior": A pattern in how they act or work
            - "contradiction": An inconsistency with previous observations
            - "general": Doesn't fit other categories

        confidence: How certain you are (0.0-1.0):
            - 1.0: User explicitly stated this
            - 0.9: Strongly implied or demonstrated repeatedly
            - 0.8: Inferred with high confidence (default)
            - 0.7: Probable but with some uncertainty
            - 0.6 or below: Speculative, use sparingly

        evidence: Supporting context as a dict. Include relevant details:
            - "quote": Exact words from the user
            - "context": What prompted this observation
            - "timestamp": When this was observed
            - "related_to": Connection to other topics
            Example: {
                "quote": "I always refactor to pure functions",
                "context": "Discussing code review practices"
            }

        tags: Categorization labels. Use lowercase with hyphens. Common patterns:
            - Topics: "machine-learning", "web-development", "philosophy"
            - Projects: "project:website-redesign", "project:thesis"
            - Contexts: "context:work", "context:personal", "context:late-night"
            - Domains: "domain:finance", "domain:healthcare"

        session_id: UUID string to group observations from the same conversation.
            Generate one UUID per conversation and reuse it for all observations
            in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"

        agent_model: Which AI model made this observation (e.g., "claude-3-opus",
            "gpt-4", "claude-3.5-sonnet"). Helps track observation quality.

    Returns:
        Dict with created observation details:
        - id: Unique identifier for reference
        - created_at: Timestamp of creation
        - subject: The subject as stored
        - observation_type: The type as stored
        - confidence: The confidence score
        - tags: List of applied tags

    Examples:
        # After user mentions their coding philosophy
        await observe(
            content="User believes strongly in functional programming principles, "
                    "particularly avoiding mutable state which they call 'the root "
                    "of all evil'. They prioritize code purity over performance.",
            subject="programming_philosophy",
            observation_type="belief",
            confidence=0.95,
            evidence={
                "quote": "State is the root of all evil in programming",
                "context": "Discussing why they chose Haskell for their project"
            },
            tags=["programming", "functional-programming", "philosophy"],
            session_id="550e8400-e29b-41d4-a716-446655440000",
            agent_model="claude-3-opus"
        )

        # Noticing a work pattern
        await observe(
            content="User frequently works on complex problems late at night, "
                    "typically between 11pm and 3am, claiming better focus",
            subject="work_schedule",
            observation_type="behavior",
            confidence=0.85,
            evidence={
                "context": "Mentioned across multiple conversations over 2 weeks"
            },
            tags=["behavior", "work-habits", "productivity", "context:late-night"],
            agent_model="claude-3-opus"
        )

        # Recording a contradiction
        await observe(
            content="User now advocates for microservices architecture, but "
                    "previously argued strongly for monoliths in similar contexts",
            subject="architecture_preferences",
            observation_type="contradiction",
            confidence=0.9,
            evidence={
                "quote": "Microservices are definitely the way to go",
                "context": "Designing a new system similar to one from 3 months ago"
            },
            tags=["architecture", "contradiction", "software-design"],
            agent_model="gpt-4"
        )
    """
    # Create the observation
    observation = AgentObservation(
        content=content,
        subject=subject,
        observation_type=observation_type,
        confidence=confidence,
        evidence=evidence,
        tags=tags or [],
        session_id=uuid.UUID(session_id) if session_id else None,
        agent_model=agent_model,
        size=len(content),
        mime_type="text/plain",
        sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
        inserted_at=datetime.now(timezone.utc),
    )
    try:
        with make_session() as session:
            process_content_item(observation, session)

            if not observation.id:
                raise ValueError("Observation not created")

            logger.info(
                f"Observation created: {observation.id}, {observation.inserted_at}"
            )
            return {
                "id": observation.id,
                "created_at": observation.inserted_at.isoformat(),
                "subject": observation.subject,
                "observation_type": observation.observation_type,
                "confidence": cast(float, observation.confidence),
                "tags": observation.tags,
            }

    except Exception as e:
        logger.error(f"Error creating observation: {e}")
        raise


@mcp.tool()
async def search_observations(
    query: str,
    subject: str = "",
    tags: list[str] | None = None,
    observation_types: list[str] | None = None,
    min_confidence: float = 0.5,
    limit: int = 10,
) -> list[dict]:
    """
    Search through observations to understand the user better.

    Purpose:
        This tool searches through all observations recorded about the user using
        the 'observe' tool. Use it to recall past insights, check for patterns,
        find contradictions, or understand the user's preferences before responding.
        The more you use this tool, the more personalized and insightful your
        responses can be.

    When to use:
        - Before answering questions where user preferences might matter
        - When the user references something from the past
        - To check if current behavior aligns with past patterns
        - To find related observations on a topic
        - To build context about the user's expertise or interests
        - Whenever personalization would improve your response

    How it works:
        Uses hybrid search combining semantic similarity with keyword matching.
        Searches across multiple embedding spaces (semantic meaning and temporal
        context) to find relevant observations from different angles. This approach
        ensures you find both conceptually related and specifically mentioned items.

    Args:
        query: Natural language description of what you're looking for. The search
            matches both meaning and specific terms in observation content.
            Examples:
            - "programming preferences and coding style"
            - "opinions about artificial intelligence and AI safety"
            - "work habits productivity patterns when does user work best"
            - "previous projects the user has worked on"
            Pro tip: Use natural language but include key terms you expect to find.

        subject: Filter by exact subject identifier. Must match subjects used when
            creating observations (e.g., "programming_style", "work_habits").
            Leave empty to search all subjects. Use this when you know the exact
            subject category you want.

        tags: Filter results to only observations with these tags. Observations must
            have at least one matching tag. Use the same format as when creating:
            - ["programming", "functional-programming"]
            - ["context:work", "project:thesis"]
            - ["domain:finance", "machine-learning"]

        observation_types: Filter by type of observation:
            - "belief": Opinions or beliefs the user holds
            - "preference": Things they prefer or favor
            - "behavior": Patterns in how they act or work
            - "contradiction": Noted inconsistencies
            - "general": Other observations
            Leave as None to search all types.

        min_confidence: Only return observations with confidence >= this value.
            - Use 0.8+ for high-confidence facts
            - Use 0.5-0.7 to include inferred observations
            - Default 0.5 includes most observations
            Range: 0.0 to 1.0

        limit: Maximum results to return (1-100). Default 10. Increase when you
            need comprehensive understanding of a topic.

    Returns:
        List of observations sorted by relevance, each containing:
        - subject: What the observation is about
        - content: The full observation text
        - observation_type: Type of observation
        - evidence: Supporting context/quotes if provided
        - confidence: How certain the observation is (0-1)
        - agent_model: Which AI model made the observation
        - tags: All tags on this observation
        - created_at: When it was observed (if available)

    Examples:
        # Before discussing code architecture
        results = await search_observations(
            query="software architecture preferences microservices monoliths",
            tags=["architecture"],
            min_confidence=0.7
        )

        # Understanding work style for scheduling
        results = await search_observations(
            query="when does user work best productivity schedule",
            observation_types=["behavior", "preference"],
            subject="work_schedule"
        )

        # Check for AI safety views before discussing AI
        results = await search_observations(
            query="artificial intelligence safety alignment concerns",
            observation_types=["belief"],
            min_confidence=0.8,
            limit=20
        )

        # Find contradictions on a topic
        results = await search_observations(
            query="testing methodology unit tests integration",
            observation_types=["contradiction"],
            tags=["testing", "software-development"]
        )

    Best practices:
        - Search before making assumptions about user preferences
        - Use broad queries first, then filter with tags/types if too many results
        - Check for contradictions when user says something unexpected
        - Higher confidence observations are more reliable
        - Recent observations may override older ones on same topic
    """
    source_ids = None
    if tags or observation_types:
        with make_session() as session:
            items_query = session.query(AgentObservation.id)

            if tags:
                # Use PostgreSQL array overlap operator with proper array casting
                items_query = items_query.filter(
                    AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
                )
            if observation_types:
                items_query = items_query.filter(
                    AgentObservation.observation_type.in_(observation_types)
                )
            source_ids = [item.id for item in items_query.all()]
            if not source_ids:
                return []

    semantic_text = observation.generate_semantic_text(
        subject=subject or "",
        observation_type="".join(observation_types or []),
        content=query,
        evidence=None,
    )
    temporal = observation.generate_temporal_text(
        subject=subject or "",
        content=query,
        confidence=0,
        created_at=datetime.now(timezone.utc),
    )
    results = await search(
        [
            extract.DataChunk(data=[query]),
            extract.DataChunk(data=[semantic_text]),
            extract.DataChunk(data=[temporal]),
        ],
        previews=True,
        modalities=["semantic", "temporal"],
        limit=limit,
        min_text_score=0.8,
        filters=SearchFilters(
            subject=subject,
            confidence=min_confidence,
            tags=tags,
            observation_types=observation_types,
            source_ids=source_ids,
        ),
    )

    return [
        cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
    ]
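For orientation, here is a rough client-side sketch of exercising these tools over the streamable HTTP transport that the application mounts further down. This is not part of the commit: the endpoint URL and the mcp SDK client imports (ClientSession and streamablehttp_client, present in recent mcp releases) are assumptions.

import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main():
    # Assumed endpoint; depends on where the server is deployed.
    async with streamablehttp_client("http://localhost:8000/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Record an observation, then search it back.
            await session.call_tool(
                "observe",
                {
                    "content": "User prefers unified diffs over side-by-side views",
                    "subject": "tooling_preferences",
                    "observation_type": "preference",
                },
            )
            result = await session.call_tool(
                "search_observations", {"query": "diff view preferences"}
            )
            print(result)


asyncio.run(main())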
@@ -18,6 +18,7 @@ from memory.common.db.models import (
     ArticleFeed,
     EmailAccount,
     ForumPost,
+    AgentObservation,
 )
 
 
@@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
 
 class ChunkAdmin(ModelView, model=Chunk):
     column_list = ["id", "source_id", "embedding_model", "created_at"]
+    column_sortable_list = ["created_at"]
 
 
 class MailMessageAdmin(ModelView, model=MailMessage):
@@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
     column_searchable_list = ["name", "email_address"]
 
 
+class AgentObservationAdmin(ModelView, model=AgentObservation):
+    column_list = [
+        "id",
+        "content",
+        "subject",
+        "observation_type",
+        "confidence",
+        "evidence",
+        "inserted_at",
+    ]
+    column_searchable_list = ["subject", "observation_type"]
+
+
 def setup_admin(admin: Admin):
     """Add all admin views to the admin instance."""
     admin.add_view(SourceItemAdmin)
+    admin.add_view(AgentObservationAdmin)
     admin.add_view(ChunkAdmin)
     admin.add_view(EmailAccountAdmin)
     admin.add_view(MailMessageAdmin)
@@ -2,6 +2,7 @@
 FastAPI application for the knowledge base.
 """
 
+import contextlib
 import pathlib
 import logging
 from typing import Annotated, Optional
@@ -15,15 +16,25 @@ from memory.common import extract
 from memory.common.db.connection import get_engine
 from memory.api.admin import setup_admin
 from memory.api.search import search, SearchResult
+from memory.api.MCP.tools import mcp
 
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="Knowledge Base API")
+
+@contextlib.asynccontextmanager
+async def lifespan(app: FastAPI):
+    async with contextlib.AsyncExitStack() as stack:
+        await stack.enter_async_context(mcp.session_manager.run())
+        yield
+
+
+app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
+
 # SQLAdmin setup
 engine = get_engine()
 admin = Admin(app, engine)
 setup_admin(admin)
+app.mount("/", mcp.streamable_http_app())
 
 
 @app.get("/health")
@@ -104,4 +115,7 @@ def main():
 
 
 if __name__ == "__main__":
+    from memory.common.qdrant import setup_qdrant
+
+    setup_qdrant()
     main()
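A note on the hunk above: FastMCP's streamable HTTP transport expects its session manager to be running before the first request arrives, which is why the app now enters mcp.session_manager.run() inside a lifespan hook and holds it open via AsyncExitStack until shutdown. A minimal sketch of the same FastAPI lifespan pattern in isolation (the commented-out service is illustrative, not part of this commit):

import contextlib

from fastapi import FastAPI


@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    async with contextlib.AsyncExitStack() as stack:
        # Enter long-lived async resources here, e.g.:
        # await stack.enter_async_context(some_service.run())
        yield  # the app serves requests while the stack is open
    # the stack unwinds here on shutdown


app = FastAPI(lifespan=lifespan)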
@@ -2,21 +2,29 @@
 Search endpoints for the knowledge base API.
 """
 
+import asyncio
 import base64
+from hashlib import sha256
 import io
-from collections import defaultdict
-from typing import Callable, Optional
 import logging
+from collections import defaultdict
+from typing import Any, Callable, Optional, TypedDict, NotRequired
+
+import bm25s
+import Stemmer
+import qdrant_client
 from PIL import Image
 from pydantic import BaseModel
-import qdrant_client
 from qdrant_client.http import models as qdrant_models
 
-from memory.common import embedding, qdrant, extract, settings
-from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
+from memory.common import embedding, extract, qdrant, settings
+from memory.common.collections import (
+    ALL_COLLECTIONS,
+    MULTIMODAL_COLLECTIONS,
+    TEXT_COLLECTIONS,
+)
 from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
+from memory.common.db.models import Chunk
 
 logger = logging.getLogger(__name__)
@@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
     preview: Optional[str | None] = None
 
 
+class SourceData(BaseModel):
+    """Holds source item data to avoid SQLAlchemy session issues"""
+
+    id: int
+    size: int | None
+    mime_type: str | None
+    filename: str | None
+    content: str | dict | None
+    content_length: int
+
+
 class SearchResponse(BaseModel):
     collection: str
     results: list[dict]
@@ -38,16 +57,46 @@ class SearchResult(BaseModel):
     size: int
     mime_type: str
     chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
+    content: Optional[str | dict] = None
     filename: Optional[str] = None
 
 
+class SearchFilters(TypedDict):
+    subject: NotRequired[str | None]
+    confidence: NotRequired[float]
+    tags: NotRequired[list[str] | None]
+    observation_types: NotRequired[list[str] | None]
+    source_ids: NotRequired[list[int] | None]
+
+
+async def with_timeout(
+    call, timeout: int = 2
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    """
+    Run a function with a timeout.
+
+    Args:
+        call: The function to run
+        timeout: The timeout in seconds
+    """
+    try:
+        return await asyncio.wait_for(call, timeout=timeout)
+    except TimeoutError:
+        logger.warning(f"Search timed out after {timeout}s")
+        return []
+    except Exception as e:
+        logger.error(f"Search failed: {e}")
+        return []
+
+
 def annotated_chunk(
     chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
+) -> tuple[SourceData, AnnotatedChunk]:
     def serialize_item(item: bytes | str | Image.Image) -> str | None:
         if not previews and not isinstance(item, str):
             return None
+        if not previews and isinstance(item, str):
+            return item[:100]
+
         if isinstance(item, Image.Image):
             buffer = io.BytesIO()
@@ -68,7 +117,19 @@ def annotated_chunk(
         for k, v in metadata.items()
         if k not in ["content", "filename", "size", "content_type", "tags"]
     }
-    return chunk.source, AnnotatedChunk(
+
+    # Prefetch all needed source data while in session
+    source = chunk.source
+    source_data = SourceData(
+        id=source.id,
+        size=source.size,
+        mime_type=source.mime_type,
+        filename=source.filename,
+        content=source.display_contents,
+        content_length=len(source.content) if source.content else 0,
+    )
+
+    return source_data, AnnotatedChunk(
         id=str(chunk.id),
         score=search_result.score,
         metadata=metadata,
@@ -76,24 +137,28 @@
     )
 
 
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
     items = defaultdict(list)
+    source_lookup = {}
+
     for source, chunk in chunks:
-        items[source].append(chunk)
+        items[source.id].append(chunk)
+        source_lookup[source.id] = source
+
     return [
         SearchResult(
             id=source.id,
-            size=source.size or len(source.content),
+            size=source.size or source.content_length,
             mime_type=source.mime_type or "text/plain",
             filename=source.filename
            and source.filename.replace(
                 str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
             ),
-            content=source.display_contents,
+            content=source.content,
             chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
         )
-        for source, chunks in items.items()
+        for source_id, chunks in items.items()
+        for source in [source_lookup[source_id]]
     ]
@@ -104,25 +169,31 @@ def query_chunks(
     embedder: Callable,
     min_score: float = 0.0,
     limit: int = 10,
+    filters: dict[str, Any] | None = None,
 ) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
+    if not upload_data or not allowed_modalities:
         return {}
 
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
+    chunks = [chunk for chunk in upload_data if chunk.data]
     if not chunks:
         logger.error(f"No chunks to embed for {allowed_modalities}")
         return {}
 
-    vector = embedder(chunks, input_type="query")[0]
+    logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
+    for c in chunks:
+        logger.error(f"Chunk: {c.data}")
+    vectors = embedder([c.data for c in chunks], input_type="query")
+
     return {
         collection: [
             r
+            for vector in vectors
             for r in qdrant.search_vectors(
                 client=client,
                 collection_name=collection,
                 query_vector=vector,
                 limit=limit,
+                filter_params=filters,
             )
             if r.score >= min_score
         ]
@@ -130,6 +201,122 @@
 }
 
 
+async def search_bm25(
+    query: str,
+    modalities: list[str],
+    limit: int = 10,
+    filters: SearchFilters = SearchFilters(),
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    with make_session() as db:
+        items_query = db.query(Chunk.id, Chunk.content).filter(
+            Chunk.collection_name.in_(modalities)
+        )
+        if source_ids := filters.get("source_ids"):
+            items_query = items_query.filter(Chunk.source_id.in_(source_ids))
+        items = items_query.all()
+        item_ids = {
+            sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
+            for item in items
+        }
+        corpus = [item.content.lower().strip() for item in items]
+
+    stemmer = Stemmer.Stemmer("english")
+    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
+    retriever = bm25s.BM25()
+    retriever.index(corpus_tokens)
+
+    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
+    results, scores = retriever.retrieve(
+        query_tokens, k=min(limit, len(corpus)), corpus=corpus
+    )
+
+    item_scores = {
+        item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
+        for doc, score in zip(results[0], scores[0])
+    }
+
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
+        results = []
+        for chunk in chunks:
+            # Prefetch all needed source data while in session
+            source = chunk.source
+            source_data = SourceData(
+                id=source.id,
+                size=source.size,
+                mime_type=source.mime_type,
+                filename=source.filename,
+                content=source.display_contents,
+                content_length=len(source.content) if source.content else 0,
+            )
+
+            annotated = AnnotatedChunk(
+                id=str(chunk.id),
+                score=item_scores[chunk.id],
+                metadata=source.as_payload(),
+                preview=None,
+            )
+            results.append((source_data, annotated))
+
+    return results
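search_bm25 leans on the bm25s index/retrieve API; a self-contained sketch of that flow, minus the database plumbing (corpus and query are toy data):

import bm25s
import Stemmer

corpus = [
    "a cat is a feline and likes to purr",
    "a dog is the human's best friend and loves to play",
    "a bird is a beautiful animal that can fly",
]
stemmer = Stemmer.Stemmer("english")

# Index once.
retriever = bm25s.BM25()
retriever.index(bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer))

# Retrieve the top-k documents with their BM25 scores.
docs, scores = retriever.retrieve(
    bm25s.tokenize("does the fish purr like a cat?", stemmer=stemmer),
    k=2,
    corpus=corpus,
)
print(docs[0], scores[0])

Note that, as written, search_bm25 re-tokenizes and re-indexes the whole matching corpus on every call, which is fine for small collections but a natural candidate for caching later.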
+
+
+async def search_embeddings(
+    data: list[extract.DataChunk],
+    previews: Optional[bool] = False,
+    modalities: set[str] = set(),
+    limit: int = 10,
+    min_score: float = 0.3,
+    filters: SearchFilters = SearchFilters(),
+    multimodal: bool = False,
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    """
+    Search across knowledge base using text query and optional files.
+
+    Parameters:
+    - data: List of data to search in (e.g., text, images, files)
+    - previews: Whether to include previews in the search results
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - limit: Maximum number of results
+    - min_score: Minimum score to include in the search results
+    - filters: Filters to apply to the search results
+    - multimodal: Whether to search in multimodal collections
+    """
+    query_filters = {
+        "must": [
+            {"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
+        ],
+    }
+    if tags := filters.get("tags"):
+        query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
+    if observation_types := filters.get("observation_types"):
+        query_filters["must"] += [
+            {"key": "observation_type", "match": {"any": observation_types}}
+        ]
+
+    client = qdrant.get_qdrant_client()
+    results = query_chunks(
+        client,
+        data,
+        modalities,
+        embedding.embed_text if not multimodal else embedding.embed_mixed,
+        min_score=min_score,
+        limit=limit,
+        filters=query_filters,
+    )
+    search_results = {k: results.get(k, []) for k in modalities}
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+        return [
+            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+            for chunk in chunks
+        ]
+
+
 async def search(
     data: list[extract.DataChunk],
     previews: Optional[bool] = False,
@@ -137,6 +324,7 @@ async def search(
     limit: int = 10,
     min_text_score: float = 0.3,
     min_multimodal_score: float = 0.3,
+    filters: SearchFilters = {},
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -150,40 +338,45 @@ async def search(
     Returns:
     - List of search results sorted by score
     """
-    client = qdrant.get_qdrant_client()
     allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
-        data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
-        limit=limit,
-    )
-    multimodal_results = query_chunks(
-        client,
-        data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
+
+    text_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & TEXT_COLLECTIONS,
+            limit,
+            min_text_score,
+            filters,
+            multimodal=False,
+        )
+    )
+    multimodal_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & MULTIMODAL_COLLECTIONS,
+            limit,
+            min_multimodal_score,
+            filters,
+            multimodal=True,
+        )
+    )
+    bm25_results = with_timeout(
+        search_bm25(
+            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
+            modalities,
+            limit=limit,
+            filters=filters,
+        )
+    )
+
+    results = await asyncio.gather(
+        text_embeddings_results,
+        multimodal_embeddings_results,
+        bm25_results,
+        return_exceptions=False,
+    )
+
+    results = group_chunks([c for r in results for c in r])
     return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
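Taken together, search() now fans the query out to three retrieval strategies concurrently (text embeddings, multimodal embeddings, BM25), each behind a two-second timeout, and merges the resulting chunks by source. A hedged example of driving it directly; normally it is reached via the MCP tools, and this assumes populated Postgres and Qdrant instances:

import asyncio

from memory.api.search import search, SearchFilters
from memory.common import extract

results = asyncio.run(
    search(
        [extract.DataChunk(data=["functional programming monads"])],
        previews=True,
        modalities=["blog", "book"],
        limit=5,
        filters=SearchFilters(tags=["programming"]),
    )
)
for r in results:
    print(r.id, r.mime_type, max(c.score for c in r.chunks))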
@@ -1,6 +1,7 @@
 import logging
 from typing import Literal, NotRequired, TypedDict
 
+from PIL import Image
 
 from memory.common import settings
 
@@ -14,84 +15,92 @@ Vector = list[float]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    model: str
     on_disk: NotRequired[bool]
     shards: NotRequired[int]
+    text: bool
+    multimodal: bool
 
 
 ALL_COLLECTIONS: dict[str, Collection] = {
     "mail": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "chat": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "git": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "book": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "blog": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "forum": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "text": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
-    # Multimodal
     "photo": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     "comic": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
    },
     "doc": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     # Observations
     "semantic": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "temporal": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
 }
 TEXT_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
 }
 MULTIMODAL_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
 }
 
 TYPES = {
@@ -119,5 +128,12 @@ def get_modality(mime_type: str) -> str:
     return "unknown"
 
 
-def collection_model(collection: str) -> str | None:
-    return ALL_COLLECTIONS.get(collection, {}).get("model")
+def collection_model(
+    collection: str, text: str, images: list[Image.Image]
+) -> str | None:
+    config = ALL_COLLECTIONS.get(collection, {})
+    if images and config.get("multimodal"):
+        return settings.MIXED_EMBEDDING_MODEL
+    if text and config.get("text"):
+        return settings.TEXT_EMBEDDING_MODEL
+    return "unknown"
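The reworked collection_model picks the embedding model from what the chunk actually contains rather than from a static per-collection model field; roughly (illustrative calls, actual values depend on the configured settings):

from PIL import Image

from memory.common import settings
from memory.common.collections import collection_model

img = Image.new("RGB", (1, 1))

# Any images in a multimodal collection -> mixed embedding model.
assert collection_model("photo", "", [img]) == settings.MIXED_EMBEDDING_MODEL
# Text in a text collection -> text embedding model.
assert collection_model("mail", "hello", []) == settings.TEXT_EMBEDDING_MODEL
# Text in an image-only collection falls through.
assert collection_model("photo", "hello", []) == "unknown"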
@ -205,7 +205,9 @@ class SourceItem(Base):
|
|||||||
id = Column(BigInteger, primary_key=True)
|
id = Column(BigInteger, primary_key=True)
|
||||||
modality = Column(Text, nullable=False)
|
modality = Column(Text, nullable=False)
|
||||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||||
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
|
inserted_at = Column(
|
||||||
|
DateTime(timezone=True), server_default=func.now(), default=func.now()
|
||||||
|
)
|
||||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
size = Column(Integer)
|
size = Column(Integer)
|
||||||
mime_type = Column(Text)
|
mime_type = Column(Text)
|
||||||
@@ -261,14 +263,15 @@ class SourceItem(Base):
         images = [c for c in data.data if isinstance(c, Image.Image)]
         image_names = image_filenames(chunk_id, images)

+        modality = data.modality or cast(str, self.modality)
         chunk = Chunk(
             id=chunk_id,
             source=self,
             content=text or None,
             images=images,
             file_paths=image_names,
-            collection_name=data.collection_name or cast(str, self.modality),
-            embedding_model=collections.collection_model(cast(str, self.modality)),
+            collection_name=modality,
+            embedding_model=collections.collection_model(modality, text, images),
             item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
         )
         return chunk
@@ -284,5 +287,8 @@ class SourceItem(Base):
         }

     @property
-    def display_contents(self) -> str | None:
-        return cast(str | None, self.content) or cast(str | None, self.filename)
+    def display_contents(self) -> str | dict | None:
+        return {
+            "tags": self.tags,
+            "size": self.size,
+        }
@@ -570,13 +570,23 @@ class AgentObservation(SourceItem):
         """Get all contradictions involving this observation."""
         return self.contradictions_as_first + self.contradictions_as_second

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[extract.DataChunk]:
+    @property
+    def display_contents(self) -> dict:
+        return {
+            "subject": self.subject,
+            "content": self.content,
+            "observation_type": self.observation_type,
+            "evidence": self.evidence,
+            "confidence": self.confidence,
+            "agent_model": self.agent_model,
+            "tags": self.tags,
+        }
+
+    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
         """
         Generate multiple chunks for different embedding dimensions.
         Each chunk goes to a different Qdrant collection for specialized search.
         """
-        chunks = []

         # 1. Semantic chunk - standard content representation
         semantic_text = observation.generate_semantic_text(
             cast(str, self.subject),
@@ -584,12 +594,10 @@ class AgentObservation(SourceItem):
             cast(str, self.content),
             cast(observation.Evidence, self.evidence),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[semantic_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
-                collection_name="semantic",
-            )
-        )
+        semantic_chunk = extract.DataChunk(
+            data=[semantic_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+            modality="semantic",
+        )

         # 2. Temporal chunk - time-aware representation
@@ -599,13 +607,28 @@ class AgentObservation(SourceItem):
             cast(float, self.confidence),
             cast(datetime, self.inserted_at),
         )
-        chunks.append(
-            extract.DataChunk(
-                data=[temporal_text],
-                metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
-                collection_name="temporal",
-            )
-        )
+        temporal_chunk = extract.DataChunk(
+            data=[temporal_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
+            modality="temporal",
+        )
+
+        others = [
+            self._make_chunk(
+                extract.DataChunk(
+                    data=[i],
+                    metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+                    modality="semantic",
+                )
+            )
+            for i in [
+                self.content,
+                self.subject,
+                self.observation_type,
+                self.evidence.get("quote", ""),
+            ]
+            if i
+        ]

         # TODO: Add more embedding dimensions here:
         # 3. Epistemic chunk - belief structure focused
@@ -632,4 +655,7 @@ class AgentObservation(SourceItem):
         # collection_name="observations_relational"
         # ))

-        return chunks
+        return [
+            self._make_chunk(semantic_chunk),
+            self._make_chunk(temporal_chunk),
+        ] + others
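
Taken together, `data_chunks` now yields the two composite chunks plus one single-field semantic chunk per non-empty attribute. A hedged sketch of the resulting shape (field values invented):

    obs = AgentObservation(
        subject="programming style",
        content="The user prefers functional programming",
        observation_type="preference",
        evidence={"quote": "I always use map and filter"},
        confidence=0.9,
    )
    chunks = obs.data_chunks()
    # -> [semantic composite, temporal composite,
    #     content, subject, observation_type, evidence quote]
    # six chunks here; empty fields are dropped by the `if i` guard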
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)


 def embed_chunks(
-    chunks: list[str] | list[list[extract.MulitmodalChunk]],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
@@ -24,52 +24,50 @@ def embed_chunks(
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks,  # type: ignore
+            chunks,
             model=model,
             input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore
+    texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
+    return cast(
+        list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
+    )
+
+
+def break_chunk(
+    chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
+) -> list[extract.MulitmodalChunk]:
+    result = []
+    for c in chunk:
+        if isinstance(c, str):
+            result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
+        else:
+            # non-text items (e.g. images) pass through unchanged
+            result.append(c)
+    return result


 def embed_text(
-    texts: list[str],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunks = [
-        c
-        for text in texts
-        if isinstance(text, str)
-        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
-        if c.strip()
-    ]
-    if not chunks:
+    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
+    if not any(chunked_chunks):
         return []

-    try:
-        return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:  # type: ignore
-        logger.error(f"Error embedding text: {e}")
-        logger.debug(f"Text: {texts}")
-        raise
+    return embed_chunks(chunked_chunks, model, input_type)


 def embed_mixed(
-    items: list[extract.MulitmodalChunk],
+    items: list[list[extract.MulitmodalChunk]],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
-        if isinstance(item, str):
-            return [
-                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
-            ]
-        return [item]
-
-    chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model, input_type)
+    chunked_chunks = [break_chunk(item, chunk_size) for item in items]
+    return embed_chunks(chunked_chunks, model, input_type)


 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
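
`break_chunk` gives `embed_text` and `embed_mixed` a single chunking path: string parts longer than the token budget are split by `chunk_text` with OVERLAP_TOKENS of overlap, while non-string parts such as images are kept in place. A sketch of the intent (inputs invented):

    piece = ["a very long essay ...", some_image, "short tail"]
    broken = break_chunk(piece, chunk_size=512)
    # -> ["a very long essay (part 1)", "a very long essay (part 2)",
    #     some_image, "short tail"]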
@@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:

 def embed_source_item(item: SourceItem) -> list[Chunk]:
     chunks = list(item.data_chunks())
+    logger.error(
+        f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
+    )
     if not chunks:
         return []

@@ -21,7 +21,7 @@ class DataChunk:
     data: Sequence[MulitmodalChunk]
     metadata: dict[str, Any] = field(default_factory=dict)
     mime_type: str = "text/plain"
-    collection_name: str | None = None
+    modality: str | None = None


 @contextmanager
@@ -8,7 +8,7 @@ class Evidence(TypedDict):


 def generate_semantic_text(
-    subject: str, observation_type: str, content: str, evidence: Evidence
+    subject: str, observation_type: str, content: str, evidence: Evidence | None = None
 ) -> str:
     """Generate text optimized for semantic similarity search."""
     parts = [
@@ -54,8 +54,9 @@ def generate_temporal_text(
         f"Time: {time_of_day} on {day_of_week} ({time_period})",
         f"Subject: {subject}",
         f"Observation: {content}",
-        f"Confidence: {confidence}",
     ]
+    if confidence:
+        parts.append(f"Confidence: {confidence}")

     return " | ".join(parts)

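
With the guard in place, observations stored without a confidence score no longer emit a dangling "Confidence: None" segment. Expected output, roughly (timestamps invented, argument order as in the caller above):

    generate_temporal_text("diet", "skips breakfast", 0.8, inserted_at)
    # "Time: 09:15 on Tuesday (morning) | Subject: diet | Observation: skips breakfast | Confidence: 0.8"
    generate_temporal_text("diet", "skips breakfast", None, inserted_at)
    # "Time: 09:15 on Tuesday (morning) | Subject: diet | Observation: skips breakfast"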
@@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
 TAGS_PROMPT = """
 The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.

+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "{summary}",
@@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
 Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
 Also provide 3-5 relevant tags that capture the main topics or themes.

+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
     "summary": "your summary here",
@@ -1,62 +0,0 @@
-import argparse
-import logging
-from typing import Any
-from fastapi import UploadFile
-import httpx
-from mcp.server.fastmcp import FastMCP
-
-SERVER = "http://localhost:8000"
-
-
-logger = logging.getLogger(__name__)
-mcp = FastMCP("memory")
-
-
-async def make_request(
-    path: str,
-    method: str,
-    data: dict | None = None,
-    json: dict | None = None,
-    files: list[UploadFile] | None = None,
-) -> httpx.Response:
-    async with httpx.AsyncClient() as client:
-        return await client.request(
-            method,
-            f"{SERVER}/{path}",
-            data=data,
-            json=json,
-            files=files,  # type: ignore
-        )
-
-
-async def post_data(path: str, data: dict | None = None) -> httpx.Response:
-    return await make_request(path, "POST", data=data)
-
-
-@mcp.tool()
-async def search(
-    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
-) -> list[dict[str, Any]]:
-    logger.error(f"Searching for {query}")
-    resp = await post_data(
-        "search",
-        {
-            "query": query,
-            "previews": previews,
-            "modalities": modalities,
-            "limit": limit,
-        },
-    )
-
-    return resp.json()
-
-
-if __name__ == "__main__":
-    # Initialize and run the server
-    args = argparse.ArgumentParser()
-    args.add_argument("--server", type=str)
-    args = args.parse_args()
-
-    SERVER = args.server
-
-    mcp.run(transport=args.transport)