Compare commits

...

3 Commits

Author SHA1 Message Date
Daniel O'Connell
b10a1fb130 initial memory tools 2025-06-01 00:11:21 +02:00
Daniel O'Connell
1dd93929c1 Add embedding for observations 2025-05-31 16:51:55 +02:00
Daniel O'Connell
004bd39987 Add observations model 2025-05-31 16:15:30 +02:00
28 changed files with 3285 additions and 992 deletions

View File

@ -12,6 +12,32 @@ from alembic import context
from memory.common import settings
from memory.common.db.models import Base
# Import all models to ensure they're registered with Base.metadata
from memory.common.db.models import (
SourceItem,
Chunk,
MailMessage,
EmailAttachment,
ChatMessage,
BlogPost,
Comic,
BookSection,
ForumPost,
GithubItem,
GitCommit,
Photo,
MiscDoc,
AgentObservation,
ObservationContradiction,
ReactionPattern,
ObservationPattern,
BeliefCluster,
ConversationMetrics,
Book,
ArticleFeed,
EmailAccount,
)
# this is the Alembic Config object
config = context.config

View File

@ -0,0 +1,279 @@
"""Add observation models
Revision ID: 6554eb260176
Revises: 2524646f56f6
Create Date: 2025-05-31 15:49:47.579256
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "6554eb260176"
down_revision: Union[str, None] = "2524646f56f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"belief_cluster",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("cluster_name", sa.Text(), nullable=False),
sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False),
sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True
),
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"last_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"belief_cluster_consistency_idx",
"belief_cluster",
["internal_consistency"],
unique=False,
)
op.create_index(
"belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False
)
op.create_table(
"conversation_metrics",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("session_id", sa.UUID(), nullable=False),
sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True),
sa.Column("breakthrough_count", sa.Integer(), nullable=True),
sa.Column(
"challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True
),
sa.Column("new_insights", sa.Integer(), nullable=True),
sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True),
sa.Column("duration_minutes", sa.Integer(), nullable=True),
sa.Column("observation_count", sa.Integer(), nullable=True),
sa.Column("contradiction_count", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"conv_metrics_breakthrough_idx",
"conversation_metrics",
["breakthrough_count"],
unique=False,
)
op.create_index(
"conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False
)
op.create_index(
"conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True
)
op.create_table(
"observation_pattern",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("pattern_type", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=False),
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False),
sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"obs_pattern_confidence_idx",
"observation_pattern",
["confidence"],
unique=False,
)
op.create_index(
"obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False
)
op.create_index(
"obs_pattern_validity_idx",
"observation_pattern",
["validity_start", "validity_end"],
unique=False,
)
op.create_table(
"reaction_pattern",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("trigger_type", sa.Text(), nullable=False),
sa.Column("reaction_type", sa.Text(), nullable=False),
sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column(
"first_observed",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"last_observed",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
sa.Column(
"reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False
)
op.create_index(
"reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False
)
op.create_index(
"reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False
)
op.create_table(
"agent_observation",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("session_id", sa.UUID(), nullable=True),
sa.Column("observation_type", sa.Text(), nullable=False),
sa.Column("subject", sa.Text(), nullable=False),
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("agent_model", sa.Text(), nullable=False),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
)
op.create_index(
"agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False
)
op.create_index(
"agent_obs_session_idx", "agent_observation", ["session_id"], unique=False
)
op.create_index(
"agent_obs_subject_idx", "agent_observation", ["subject"], unique=False
)
op.create_index(
"agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False
)
op.create_table(
"observation_contradiction",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("observation_1_id", sa.BigInteger(), nullable=False),
sa.Column("observation_2_id", sa.BigInteger(), nullable=False),
sa.Column("contradiction_type", sa.Text(), nullable=False),
sa.Column(
"detected_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("detection_method", sa.Text(), nullable=False),
sa.Column("resolution", sa.Text(), nullable=True),
sa.Column(
"observation_metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
sa.ForeignKeyConstraint(
["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"obs_contra_method_idx",
"observation_contradiction",
["detection_method"],
unique=False,
)
op.create_index(
"obs_contra_obs1_idx",
"observation_contradiction",
["observation_1_id"],
unique=False,
)
op.create_index(
"obs_contra_obs2_idx",
"observation_contradiction",
["observation_2_id"],
unique=False,
)
op.create_index(
"obs_contra_type_idx",
"observation_contradiction",
["contradiction_type"],
unique=False,
)
op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("chunk", "collection_name")
op.drop_index("obs_contra_type_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction")
op.drop_index("obs_contra_method_idx", table_name="observation_contradiction")
op.drop_table("observation_contradiction")
op.drop_index("agent_obs_type_idx", table_name="agent_observation")
op.drop_index("agent_obs_subject_idx", table_name="agent_observation")
op.drop_index("agent_obs_session_idx", table_name="agent_observation")
op.drop_index("agent_obs_model_idx", table_name="agent_observation")
op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
op.drop_table("agent_observation")
op.drop_index("reaction_type_idx", table_name="reaction_pattern")
op.drop_index("reaction_trigger_idx", table_name="reaction_pattern")
op.drop_index("reaction_frequency_idx", table_name="reaction_pattern")
op.drop_table("reaction_pattern")
op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern")
op.drop_index("obs_pattern_type_idx", table_name="observation_pattern")
op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern")
op.drop_table("observation_pattern")
op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics")
op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics")
op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics")
op.drop_table("conversation_metrics")
op.drop_index("belief_cluster_name_idx", table_name="belief_cluster")
op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster")
op.drop_table("belief_cluster")

View File

@ -3,3 +3,4 @@ uvicorn==0.29.0
python-jose==3.3.0
python-multipart==0.0.9
sqladmin
mcp==1.9.2

View File

@ -5,4 +5,6 @@ alembic==1.13.1
dotenv==0.9.9
voyageai==0.3.2
qdrant-client==1.9.0
anthropic==0.18.1
anthropic==0.18.1
bm25s[full]==0.2.13

654
src/memory/api/MCP/tools.py Normal file
View File

@ -0,0 +1,654 @@
"""
MCP tools for the epistemic sparring partner system.
"""
from datetime import datetime, timezone
from hashlib import sha256
import logging
import uuid
from typing import cast
from mcp.server.fastmcp import FastMCP
from memory.common.db.models.source_item import SourceItem
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy import func, cast as sql_cast, Text
from memory.common.db.connection import make_session
from memory.common import extract
from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters
from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item
logger = logging.getLogger(__name__)
# Create MCP server instance
mcp = FastMCP("memory", stateless=True)
@mcp.tool()
async def get_all_tags() -> list[str]:
"""
Get all unique tags used across the entire knowledge base.
Purpose:
This tool retrieves all tags that have been used in the system, both from
AI observations (created with 'observe') and other content. Use it to
understand the tag taxonomy, ensure consistency, or discover related topics.
When to use:
- Before creating new observations, to use consistent tag naming
- To explore what topics/contexts have been tracked
- To build tag filters for search operations
- To understand the user's areas of interest
- For tag autocomplete or suggestion features
Returns:
Sorted list of all unique tags in the system. Tags follow patterns like:
- Topics: "machine-learning", "functional-programming"
- Projects: "project:website-redesign"
- Contexts: "context:work", "context:late-night"
- Domains: "domain:finance"
Example:
# Get all tags to ensure consistency
tags = await get_all_tags()
# Returns: ["ai-safety", "context:work", "functional-programming",
# "machine-learning", "project:thesis", ...]
# Use to check if a topic has been discussed before
if "quantum-computing" in tags:
# Search for related observations
observations = await search_observations(
query="quantum computing",
tags=["quantum-computing"]
)
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""
Get all unique subjects from observations about the user.
Purpose:
This tool retrieves all subject identifiers that have been used in
observations (created with 'observe'). Subjects are the consistent
identifiers for what observations are about. Use this to understand
what aspects of the user have been tracked and ensure consistency.
When to use:
- Before creating new observations, to use existing subject names
- To discover what aspects of the user have been observed
- To build subject filters for targeted searches
- To ensure consistent naming across observations
- To get an overview of the user model
Returns:
Sorted list of all unique subjects. Common patterns include:
- "programming_style", "programming_philosophy"
- "work_habits", "work_schedule"
- "ai_beliefs", "ai_safety_beliefs"
- "learning_preferences"
- "communication_style"
Example:
# Get all subjects to ensure consistency
subjects = await get_all_subjects()
# Returns: ["ai_safety_beliefs", "architecture_preferences",
# "programming_philosophy", "work_schedule", ...]
# Use to check what we know about the user
if "programming_style" in subjects:
# Get all programming-related observations
observations = await search_observations(
query="programming",
subject="programming_style"
)
Best practices:
- Always check existing subjects before creating new ones
- Use snake_case for consistency
- Be specific but not too granular
- Group related observations under same subject
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""
Get all unique observation types that have been used.
Purpose:
This tool retrieves the distinct observation types that have been recorded
in the system. While the standard types are predefined (belief, preference,
behavior, contradiction, general), this shows what's actually been used.
Helpful for understanding the distribution of observation types.
When to use:
- To see what types of observations have been made
- To understand the balance of different observation types
- To check if all standard types are being utilized
- For analytics or reporting on observation patterns
Standard types:
- "belief": Opinions or beliefs the user holds
- "preference": Things they prefer or favor
- "behavior": Patterns in how they act or work
- "contradiction": Noted inconsistencies
- "general": Observations that don't fit other categories
Returns:
List of observation types that have actually been used in the system.
Example:
# Check what types of observations exist
types = await get_all_observation_types()
# Returns: ["behavior", "belief", "contradiction", "preference"]
# Use to analyze observation distribution
for obs_type in types:
observations = await search_observations(
query="",
observation_types=[obs_type],
limit=100
)
print(f"{obs_type}: {len(observations)} observations")
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
@mcp.tool()
async def search_knowledge_base(
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict]:
"""
Search through the user's stored knowledge and content.
Purpose:
This tool searches the user's personal knowledge base - a collection of
their saved content including emails, documents, blog posts, books, and
more. Use this alongside 'search_observations' to build a complete picture:
- search_knowledge_base: Finds user's actual content and information
- search_observations: Finds AI-generated insights about the user
Together they enable deeply personalized, context-aware assistance.
When to use:
- User asks about something they've read/written/received
- You need to find specific content the user has saved
- User references a document, email, or article
- To provide quotes or information from user's sources
- To understand context from user's past communications
- When user says "that article about..." or similar references
How it works:
Uses hybrid search combining semantic understanding with keyword matching.
This means it finds content based on meaning AND specific terms, giving
you the best of both approaches. Results are ranked by relevance.
Args:
query: Natural language search query. Be descriptive about what you're
looking for. The search understands meaning but also values exact terms.
Examples:
- "email about project deadline from last week"
- "functional programming articles comparing Haskell and Scala"
- "that blog post about AI safety and alignment"
- "recipe for chocolate cake Sarah sent me"
Pro tip: Include both concepts and specific keywords for best results.
previews: Whether to include content snippets in results.
- True: Returns preview text and image previews (useful for quick scanning)
- False: Returns just metadata (faster, less data)
Default is False.
modalities: Types of content to search. Leave empty to search all.
Available types:
- 'email': Email messages
- 'blog': Blog posts and articles
- 'book': Book sections and ebooks
- 'forum': Forum posts (e.g., LessWrong, Reddit)
- 'observation': AI observations (use search_observations instead)
- 'photo': Images with extracted text
- 'comic': Comics and graphic content
- 'webpage': General web pages
Examples:
- ["email"] - only emails
- ["blog", "forum"] - articles and forum posts
- [] - search everything
limit: Maximum results to return (1-100). Default 10.
Increase for comprehensive searches, decrease for quick lookups.
Returns:
List of search results ranked by relevance, each containing:
- id: Unique identifier for the source item
- score: Relevance score (0-1, higher is better)
- chunks: Matching content segments with metadata
- content: Full details including:
- For emails: sender, recipient, subject, date
- For blogs: author, title, url, publish date
- For books: title, author, chapter info
- Type-specific fields for each modality
- filename: Path to file if content is stored on disk
Examples:
# Find specific email
results = await search_knowledge_base(
query="Sarah deadline project proposal next Friday",
modalities=["email"],
previews=True,
limit=5
)
# Search for technical articles
results = await search_knowledge_base(
query="functional programming monads category theory",
modalities=["blog", "book"],
limit=20
)
# Find everything about a topic
results = await search_knowledge_base(
query="machine learning deployment kubernetes docker",
previews=True
)
# Quick lookup of a remembered document
results = await search_knowledge_base(
query="tax forms 2023 accountant recommendations",
modalities=["email"],
limit=3
)
Best practices:
- Include context in queries ("email from Sarah" vs just "Sarah")
- Use modalities to filter when you know the content type
- Enable previews when you need to verify content before using
- Combine with search_observations for complete context
- Higher scores (>0.7) indicate strong matches
- If no results, try broader queries or different phrasing
"""
logger.info(f"MCP search for: {query}")
upload_data = extract.extract_text(query)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.3,
min_multimodal_score=0.3,
)
# Convert SearchResult objects to dictionaries for MCP
return [result.model_dump() for result in results]
@mcp.tool()
async def observe(
content: str,
subject: str,
observation_type: str = "general",
confidence: float = 0.8,
evidence: dict | None = None,
tags: list[str] | None = None,
session_id: str | None = None,
agent_model: str = "unknown",
) -> dict:
"""
Record an observation about the user to build long-term understanding.
Purpose:
This tool is part of a memory system designed to help AI agents build a
deep, persistent understanding of users over time. Use it to record any
notable information about the user's preferences, beliefs, behaviors, or
characteristics. These observations accumulate to create a comprehensive
model of the user that improves future interactions.
Quick Reference:
# Most common patterns:
observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")
When to use:
- User expresses a preference or opinion
- You notice a behavioral pattern
- User reveals information about their work/life/interests
- You spot a contradiction with previous statements
- Any insight that would help understand the user better in future
Important: Be an active observer. Don't wait to be asked - proactively record
observations throughout conversations to build understanding.
Args:
content: The observation itself. Be specific and detailed. Write complete
thoughts that will make sense when read months later without context.
Bad: "Likes FP"
Good: "User strongly prefers functional programming paradigms, especially
pure functions and immutability, considering them more maintainable"
subject: A consistent identifier for what this observation is about. Use
snake_case and be consistent across observations to enable tracking.
Examples:
- "programming_style" (not "coding" or "development")
- "work_habits" (not "productivity" or "work_patterns")
- "ai_safety_beliefs" (not "AI" or "artificial_intelligence")
observation_type: Categorize the observation:
- "belief": An opinion or belief the user holds
- "preference": Something they prefer or favor
- "behavior": A pattern in how they act or work
- "contradiction": An inconsistency with previous observations
- "general": Doesn't fit other categories
confidence: How certain you are (0.0-1.0):
- 1.0: User explicitly stated this
- 0.9: Strongly implied or demonstrated repeatedly
- 0.8: Inferred with high confidence (default)
- 0.7: Probable but with some uncertainty
- 0.6 or below: Speculative, use sparingly
evidence: Supporting context as a dict. Include relevant details:
- "quote": Exact words from the user
- "context": What prompted this observation
- "timestamp": When this was observed
- "related_to": Connection to other topics
Example: {
"quote": "I always refactor to pure functions",
"context": "Discussing code review practices"
}
tags: Categorization labels. Use lowercase with hyphens. Common patterns:
- Topics: "machine-learning", "web-development", "philosophy"
- Projects: "project:website-redesign", "project:thesis"
- Contexts: "context:work", "context:personal", "context:late-night"
- Domains: "domain:finance", "domain:healthcare"
session_id: UUID string to group observations from the same conversation.
Generate one UUID per conversation and reuse it for all observations
in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"
agent_model: Which AI model made this observation (e.g., "claude-3-opus",
"gpt-4", "claude-3.5-sonnet"). Helps track observation quality.
Returns:
Dict with created observation details:
- id: Unique identifier for reference
- created_at: Timestamp of creation
- subject: The subject as stored
- observation_type: The type as stored
- confidence: The confidence score
- tags: List of applied tags
Examples:
# After user mentions their coding philosophy
await observe(
content="User believes strongly in functional programming principles, "
"particularly avoiding mutable state which they call 'the root "
"of all evil'. They prioritize code purity over performance.",
subject="programming_philosophy",
observation_type="belief",
confidence=0.95,
evidence={
"quote": "State is the root of all evil in programming",
"context": "Discussing why they chose Haskell for their project"
},
tags=["programming", "functional-programming", "philosophy"],
session_id="550e8400-e29b-41d4-a716-446655440000",
agent_model="claude-3-opus"
)
# Noticing a work pattern
await observe(
content="User frequently works on complex problems late at night, "
"typically between 11pm and 3am, claiming better focus",
subject="work_schedule",
observation_type="behavior",
confidence=0.85,
evidence={
"context": "Mentioned across multiple conversations over 2 weeks"
},
tags=["behavior", "work-habits", "productivity", "context:late-night"],
agent_model="claude-3-opus"
)
# Recording a contradiction
await observe(
content="User now advocates for microservices architecture, but "
"previously argued strongly for monoliths in similar contexts",
subject="architecture_preferences",
observation_type="contradiction",
confidence=0.9,
evidence={
"quote": "Microservices are definitely the way to go",
"context": "Designing a new system similar to one from 3 months ago"
},
tags=["architecture", "contradiction", "software-design"],
agent_model="gpt-4"
)
"""
# Create the observation
observation = AgentObservation(
content=content,
subject=subject,
observation_type=observation_type,
confidence=confidence,
evidence=evidence,
tags=tags or [],
session_id=uuid.UUID(session_id) if session_id else None,
agent_model=agent_model,
size=len(content),
mime_type="text/plain",
sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
inserted_at=datetime.now(timezone.utc),
)
try:
with make_session() as session:
process_content_item(observation, session)
if not observation.id:
raise ValueError("Observation not created")
logger.info(
f"Observation created: {observation.id}, {observation.inserted_at}"
)
return {
"id": observation.id,
"created_at": observation.inserted_at.isoformat(),
"subject": observation.subject,
"observation_type": observation.observation_type,
"confidence": cast(float, observation.confidence),
"tags": observation.tags,
}
except Exception as e:
logger.error(f"Error creating observation: {e}")
raise
@mcp.tool()
async def search_observations(
query: str,
subject: str = "",
tags: list[str] | None = None,
observation_types: list[str] | None = None,
min_confidence: float = 0.5,
limit: int = 10,
) -> list[dict]:
"""
Search through observations to understand the user better.
Purpose:
This tool searches through all observations recorded about the user using
the 'observe' tool. Use it to recall past insights, check for patterns,
find contradictions, or understand the user's preferences before responding.
The more you use this tool, the more personalized and insightful your
responses can be.
When to use:
- Before answering questions where user preferences might matter
- When the user references something from the past
- To check if current behavior aligns with past patterns
- To find related observations on a topic
- To build context about the user's expertise or interests
- Whenever personalization would improve your response
How it works:
Uses hybrid search combining semantic similarity with keyword matching.
Searches across multiple embedding spaces (semantic meaning and temporal
context) to find relevant observations from different angles. This approach
ensures you find both conceptually related and specifically mentioned items.
Args:
query: Natural language description of what you're looking for. The search
matches both meaning and specific terms in observation content.
Examples:
- "programming preferences and coding style"
- "opinions about artificial intelligence and AI safety"
- "work habits productivity patterns when does user work best"
- "previous projects the user has worked on"
Pro tip: Use natural language but include key terms you expect to find.
subject: Filter by exact subject identifier. Must match subjects used when
creating observations (e.g., "programming_style", "work_habits").
Leave empty to search all subjects. Use this when you know the exact
subject category you want.
tags: Filter results to only observations with these tags. Observations must
have at least one matching tag. Use the same format as when creating:
- ["programming", "functional-programming"]
- ["context:work", "project:thesis"]
- ["domain:finance", "machine-learning"]
observation_types: Filter by type of observation:
- "belief": Opinions or beliefs the user holds
- "preference": Things they prefer or favor
- "behavior": Patterns in how they act or work
- "contradiction": Noted inconsistencies
- "general": Other observations
Leave as None to search all types.
min_confidence: Only return observations with confidence >= this value.
- Use 0.8+ for high-confidence facts
- Use 0.5-0.7 to include inferred observations
- Default 0.5 includes most observations
Range: 0.0 to 1.0
limit: Maximum results to return (1-100). Default 10. Increase when you
need comprehensive understanding of a topic.
Returns:
List of observations sorted by relevance, each containing:
- subject: What the observation is about
- content: The full observation text
- observation_type: Type of observation
- evidence: Supporting context/quotes if provided
- confidence: How certain the observation is (0-1)
- agent_model: Which AI model made the observation
- tags: All tags on this observation
- created_at: When it was observed (if available)
Examples:
# Before discussing code architecture
results = await search_observations(
query="software architecture preferences microservices monoliths",
tags=["architecture"],
min_confidence=0.7
)
# Understanding work style for scheduling
results = await search_observations(
query="when does user work best productivity schedule",
observation_types=["behavior", "preference"],
subject="work_schedule"
)
# Check for AI safety views before discussing AI
results = await search_observations(
query="artificial intelligence safety alignment concerns",
observation_types=["belief"],
min_confidence=0.8,
limit=20
)
# Find contradictions on a topic
results = await search_observations(
query="testing methodology unit tests integration",
observation_types=["contradiction"],
tags=["testing", "software-development"]
)
Best practices:
- Search before making assumptions about user preferences
- Use broad queries first, then filter with tags/types if too many results
- Check for contradictions when user says something unexpected
- Higher confidence observations are more reliable
- Recent observations may override older ones on same topic
"""
source_ids = None
if tags or observation_types:
with make_session() as session:
items_query = session.query(AgentObservation.id)
if tags:
# Use PostgreSQL array overlap operator with proper array casting
items_query = items_query.filter(
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
if observation_types:
items_query = items_query.filter(
AgentObservation.observation_type.in_(observation_types)
)
source_ids = [item.id for item in items_query.all()]
if not source_ids:
return []
semantic_text = observation.generate_semantic_text(
subject=subject or "",
observation_type="".join(observation_types or []),
content=query,
evidence=None,
)
temporal = observation.generate_temporal_text(
subject=subject or "",
content=query,
confidence=0,
created_at=datetime.now(timezone.utc),
)
results = await search(
[
extract.DataChunk(data=[query]),
extract.DataChunk(data=[semantic_text]),
extract.DataChunk(data=[temporal]),
],
previews=True,
modalities=["semantic", "temporal"],
limit=limit,
min_text_score=0.8,
filters=SearchFilters(
subject=subject,
confidence=min_confidence,
tags=tags,
observation_types=observation_types,
source_ids=source_ids,
),
)
return [
cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
]

View File

@ -18,6 +18,7 @@ from memory.common.db.models import (
ArticleFeed,
EmailAccount,
ForumPost,
AgentObservation,
)
@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
class ChunkAdmin(ModelView, model=Chunk):
column_list = ["id", "source_id", "embedding_model", "created_at"]
column_sortable_list = ["created_at"]
class MailMessageAdmin(ModelView, model=MailMessage):
@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
column_searchable_list = ["name", "email_address"]
class AgentObservationAdmin(ModelView, model=AgentObservation):
column_list = [
"id",
"content",
"subject",
"observation_type",
"confidence",
"evidence",
"inserted_at",
]
column_searchable_list = ["subject", "observation_type"]
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance."""
admin.add_view(SourceItemAdmin)
admin.add_view(AgentObservationAdmin)
admin.add_view(ChunkAdmin)
admin.add_view(EmailAccountAdmin)
admin.add_view(MailMessageAdmin)

View File

@ -2,6 +2,7 @@
FastAPI application for the knowledge base.
"""
import contextlib
import pathlib
import logging
from typing import Annotated, Optional
@ -15,15 +16,25 @@ from memory.common import extract
from memory.common.db.connection import get_engine
from memory.api.admin import setup_admin
from memory.api.search import search, SearchResult
from memory.api.MCP.tools import mcp
logger = logging.getLogger(__name__)
app = FastAPI(title="Knowledge Base API")
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
yield
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
# SQLAdmin setup
engine = get_engine()
admin = Admin(app, engine)
setup_admin(admin)
app.mount("/", mcp.streamable_http_app())
@app.get("/health")
@ -104,4 +115,7 @@ def main():
if __name__ == "__main__":
from memory.common.qdrant import setup_qdrant
setup_qdrant()
main()

View File

@ -2,21 +2,29 @@
Search endpoints for the knowledge base API.
"""
import asyncio
import base64
from hashlib import sha256
import io
from collections import defaultdict
from typing import Callable, Optional
import logging
from collections import defaultdict
from typing import Any, Callable, Optional, TypedDict, NotRequired
import bm25s
import Stemmer
import qdrant_client
from PIL import Image
from pydantic import BaseModel
import qdrant_client
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, qdrant, extract, settings
from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
from memory.common import embedding, extract, qdrant, settings
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
preview: Optional[str | None] = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content: str | dict | None
content_length: int
class SearchResponse(BaseModel):
collection: str
results: list[dict]
@ -38,16 +57,46 @@ class SearchResult(BaseModel):
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str] = None
content: Optional[str | dict] = None
filename: Optional[str] = None
class SearchFilters(TypedDict):
subject: NotRequired[str | None]
confidence: NotRequired[float]
tags: NotRequired[list[str] | None]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
logger.error(f"Search failed: {e}")
return []
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceItem, AnnotatedChunk]:
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if not previews and isinstance(item, str):
return item[:100]
if isinstance(item, Image.Image):
buffer = io.BytesIO()
@ -68,7 +117,19 @@ def annotated_chunk(
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
return chunk.source, AnnotatedChunk(
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
return source_data, AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
@ -76,24 +137,28 @@ def annotated_chunk(
)
def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source].append(chunk)
items[source.id].append(chunk)
source_lookup[source.id] = source
return [
SearchResult(
id=source.id,
size=source.size or len(source.content),
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=source.display_contents,
content=source.content,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
)
for source, chunks in items.items()
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]
@ -104,25 +169,31 @@ def query_chunks(
embedder: Callable,
min_score: float = 0.0,
limit: int = 10,
filters: dict[str, Any] | None = None,
) -> dict[str, list[qdrant_models.ScoredPoint]]:
if not upload_data:
if not upload_data or not allowed_modalities:
return {}
chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
chunks = [chunk for chunk in upload_data if chunk.data]
if not chunks:
logger.error(f"No chunks to embed for {allowed_modalities}")
return {}
vector = embedder(chunks, input_type="query")[0]
logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
for c in chunks:
logger.error(f"Chunk: {c.data}")
vectors = embedder([c.data for c in chunks], input_type="query")
return {
collection: [
r
for vector in vectors
for r in qdrant.search_vectors(
client=client,
collection_name=collection,
query_vector=vector,
limit=limit,
filter_params=filters,
)
if r.score >= min_score
]
@ -130,6 +201,122 @@ def query_chunks(
}
async def search_bm25(
query: str,
modalities: list[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities)
)
if source_ids := filters.get("source_ids"):
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
items = items_query.all()
item_ids = {
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
for item in items
}
corpus = [item.content.lower().strip() for item in items]
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
results, scores = retriever.retrieve(
query_tokens, k=min(limit, len(corpus)), corpus=corpus
)
item_scores = {
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source = chunk.source
source_data = SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content=source.display_contents,
content_length=len(source.content) if source.content else 0,
)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=source.as_payload(),
preview=None,
)
results.append((source_data, annotated))
return results
async def search_embeddings(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = SearchFilters(),
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Search across knowledge base using text query and optional files.
Parameters:
- data: List of data to search in (e.g., text, images, files)
- previews: Whether to include previews in the search results
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
- limit: Maximum number of results
- min_score: Minimum score to include in the search results
- filters: Filters to apply to the search results
- multimodal: Whether to search in multimodal collections
"""
query_filters = {
"must": [
{"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
],
}
if tags := filters.get("tags"):
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
if observation_types := filters.get("observation_types"):
query_filters["must"] += [
{"key": "observation_type", "match": {"any": observation_types}}
]
client = qdrant.get_qdrant_client()
results = query_chunks(
client,
data,
modalities,
embedding.embed_text if not multimodal else embedding.embed_mixed,
min_score=min_score,
limit=limit,
filters=query_filters,
)
search_results = {k: results.get(k, []) for k in modalities}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
@ -137,6 +324,7 @@ async def search(
limit: int = 10,
min_text_score: float = 0.3,
min_multimodal_score: float = 0.3,
filters: SearchFilters = {},
) -> list[SearchResult]:
"""
Search across knowledge base using text query and optional files.
@ -150,40 +338,45 @@ async def search(
Returns:
- List of search results sorted by score
"""
client = qdrant.get_qdrant_client()
allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
text_results = query_chunks(
client,
data,
allowed_modalities & TEXT_COLLECTIONS,
embedding.embed_text,
min_score=min_text_score,
limit=limit,
)
multimodal_results = query_chunks(
client,
data,
allowed_modalities,
embedding.embed_mixed,
min_score=min_multimodal_score,
limit=limit,
)
search_results = {
k: text_results.get(k, []) + multimodal_results.get(k, [])
for k in allowed_modalities
}
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
logger.error(f"Found chunks: {chunks}")
results = group_chunks(
[
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
text_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
limit,
min_text_score,
filters,
multimodal=False,
)
)
multimodal_embeddings_results = with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
)
)
bm25_results = with_timeout(
search_bm25(
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
modalities,
limit=limit,
filters=filters,
)
)
results = await asyncio.gather(
text_embeddings_results,
multimodal_embeddings_results,
bm25_results,
return_exceptions=False,
)
results = group_chunks([c for r in results for c in r])
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)

View File

@ -1,6 +1,7 @@
import logging
from typing import Literal, NotRequired, TypedDict
from PIL import Image
from memory.common import settings
@ -14,73 +15,92 @@ Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
model: str
on_disk: NotRequired[bool]
shards: NotRequired[int]
text: bool
multimodal: bool
ALL_COLLECTIONS: dict[str, Collection] = {
"mail": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"chat": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"git": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"book": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
"blog": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"forum": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": True,
"multimodal": True,
},
"text": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.TEXT_EMBEDDING_MODEL,
"text": True,
"multimodal": False,
},
# Multimodal
"photo": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
"comic": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
"doc": {
"dimension": 1024,
"distance": "Cosine",
"model": settings.MIXED_EMBEDDING_MODEL,
"text": False,
"multimodal": True,
},
# Observations
"semantic": {
"dimension": 1024,
"distance": "Cosine",
"text": True,
"multimodal": False,
},
"temporal": {
"dimension": 1024,
"distance": "Cosine",
"text": True,
"multimodal": False,
},
}
TEXT_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.TEXT_EMBEDDING_MODEL
coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
}
MULTIMODAL_COLLECTIONS = {
coll
for coll, params in ALL_COLLECTIONS.items()
if params["model"] == settings.MIXED_EMBEDDING_MODEL
coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
}
TYPES = {
@ -108,5 +128,12 @@ def get_modality(mime_type: str) -> str:
return "unknown"
def collection_model(collection: str) -> str | None:
return ALL_COLLECTIONS.get(collection, {}).get("model")
def collection_model(
collection: str, text: str, images: list[Image.Image]
) -> str | None:
config = ALL_COLLECTIONS.get(collection, {})
if images and config.get("multimodal"):
return settings.MIXED_EMBEDDING_MODEL
if text and config.get("text"):
return settings.TEXT_EMBEDDING_MODEL
return "unknown"

View File

@ -2,7 +2,7 @@
Database utilities package.
"""
from memory.common.db.models import Base
from memory.common.db.models.base import Base
from memory.common.db.connection import (
get_engine,
get_session_factory,

View File

@ -0,0 +1,61 @@
from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
Chunk,
SourceItem,
clean_filename,
)
from memory.common.db.models.source_items import (
MailMessage,
EmailAttachment,
AgentObservation,
ChatMessage,
BlogPost,
Comic,
BookSection,
ForumPost,
GithubItem,
GitCommit,
Photo,
MiscDoc,
)
from memory.common.db.models.observations import (
ObservationContradiction,
ReactionPattern,
ObservationPattern,
BeliefCluster,
ConversationMetrics,
)
from memory.common.db.models.sources import (
Book,
ArticleFeed,
EmailAccount,
)
__all__ = [
"Base",
"Chunk",
"clean_filename",
"SourceItem",
"MailMessage",
"EmailAttachment",
"AgentObservation",
"ChatMessage",
"BlogPost",
"Comic",
"BookSection",
"ForumPost",
"GithubItem",
"GitCommit",
"Photo",
"MiscDoc",
# Observations
"ObservationContradiction",
"ReactionPattern",
"ObservationPattern",
"BeliefCluster",
"ConversationMetrics",
# Sources
"Book",
"ArticleFeed",
"EmailAccount",
]

View File

@ -0,0 +1,4 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()

View File

@ -0,0 +1,171 @@
"""
Agent observation models for the epistemic sparring partner system.
"""
from sqlalchemy import (
ARRAY,
BigInteger,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
Text,
UUID,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from memory.common.db.models.base import Base
class ObservationContradiction(Base):
"""
Tracks contradictions between observations.
Can be detected automatically or reported by agents.
"""
__tablename__ = "observation_contradiction"
id = Column(BigInteger, primary_key=True)
observation_1_id = Column(
BigInteger,
ForeignKey("agent_observation.id", ondelete="CASCADE"),
nullable=False,
)
observation_2_id = Column(
BigInteger,
ForeignKey("agent_observation.id", ondelete="CASCADE"),
nullable=False,
)
contradiction_type = Column(Text, nullable=False) # direct, implied, temporal
detected_at = Column(DateTime(timezone=True), server_default=func.now())
detection_method = Column(Text, nullable=False) # manual, automatic, agent-reported
resolution = Column(Text) # How it was resolved, if at all
observation_metadata = Column(JSONB)
# Relationships - use string references to avoid circular imports
observation_1 = relationship(
"AgentObservation",
foreign_keys=[observation_1_id],
back_populates="contradictions_as_first",
)
observation_2 = relationship(
"AgentObservation",
foreign_keys=[observation_2_id],
back_populates="contradictions_as_second",
)
__table_args__ = (
Index("obs_contra_obs1_idx", "observation_1_id"),
Index("obs_contra_obs2_idx", "observation_2_id"),
Index("obs_contra_type_idx", "contradiction_type"),
Index("obs_contra_method_idx", "detection_method"),
)
class ReactionPattern(Base):
"""
Tracks patterns in how the user reacts to certain types of observations or challenges.
"""
__tablename__ = "reaction_pattern"
id = Column(BigInteger, primary_key=True)
trigger_type = Column(
Text, nullable=False
) # What kind of observation triggers this
reaction_type = Column(Text, nullable=False) # How user typically responds
frequency = Column(Numeric(3, 2), nullable=False) # How often this pattern appears
first_observed = Column(DateTime(timezone=True), server_default=func.now())
last_observed = Column(DateTime(timezone=True), server_default=func.now())
example_observations = Column(
ARRAY(BigInteger)
) # IDs of observations showing this pattern
reaction_metadata = Column(JSONB)
__table_args__ = (
Index("reaction_trigger_idx", "trigger_type"),
Index("reaction_type_idx", "reaction_type"),
Index("reaction_frequency_idx", "frequency"),
)
class ObservationPattern(Base):
"""
Higher-level patterns detected across multiple observations.
"""
__tablename__ = "observation_pattern"
id = Column(BigInteger, primary_key=True)
pattern_type = Column(Text, nullable=False) # behavioral, cognitive, emotional
description = Column(Text, nullable=False)
supporting_observations = Column(
ARRAY(BigInteger), nullable=False
) # Observation IDs
exceptions = Column(ARRAY(BigInteger)) # Observations that don't fit
confidence = Column(Numeric(3, 2), nullable=False, default=0.7)
validity_start = Column(DateTime(timezone=True))
validity_end = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now())
pattern_metadata = Column(JSONB)
__table_args__ = (
Index("obs_pattern_type_idx", "pattern_type"),
Index("obs_pattern_confidence_idx", "confidence"),
Index("obs_pattern_validity_idx", "validity_start", "validity_end"),
)
class BeliefCluster(Base):
"""
Groups of related beliefs that support or depend on each other.
"""
__tablename__ = "belief_cluster"
id = Column(BigInteger, primary_key=True)
cluster_name = Column(Text, nullable=False)
core_beliefs = Column(ARRAY(Text), nullable=False)
peripheral_beliefs = Column(ARRAY(Text))
internal_consistency = Column(Numeric(3, 2)) # How well beliefs align
supporting_observations = Column(ARRAY(BigInteger)) # Observation IDs
created_at = Column(DateTime(timezone=True), server_default=func.now())
last_updated = Column(DateTime(timezone=True), server_default=func.now())
cluster_metadata = Column(JSONB)
__table_args__ = (
Index("belief_cluster_name_idx", "cluster_name"),
Index("belief_cluster_consistency_idx", "internal_consistency"),
)
class ConversationMetrics(Base):
"""
Tracks the effectiveness and depth of conversations.
"""
__tablename__ = "conversation_metrics"
id = Column(BigInteger, primary_key=True)
session_id = Column(UUID(as_uuid=True), nullable=False)
depth_score = Column(Numeric(3, 2)) # How deep the conversation went
breakthrough_count = Column(Integer, default=0)
challenge_acceptance = Column(Numeric(3, 2)) # How well challenges were received
new_insights = Column(Integer, default=0)
user_engagement = Column(Numeric(3, 2)) # Inferred engagement level
duration_minutes = Column(Integer)
observation_count = Column(Integer, default=0)
contradiction_count = Column(Integer, default=0)
created_at = Column(DateTime(timezone=True), server_default=func.now())
metrics_metadata = Column(JSONB)
__table_args__ = (
Index("conv_metrics_session_idx", "session_id", unique=True),
Index("conv_metrics_depth_idx", "depth_score"),
Index("conv_metrics_breakthrough_idx", "breakthrough_count"),
)

View File

@ -0,0 +1,294 @@
"""
Database models for the knowledge base system.
"""
import pathlib
import re
from typing import Any, Sequence, cast
import uuid
from PIL import Image
from sqlalchemy import (
ARRAY,
UUID,
BigInteger,
CheckConstraint,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
event,
func,
)
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import Session, relationship
from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
from memory.common.db.models.base import Base
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
"""
Event listener that efficiently checks for duplicate sha256 values before flush
and removes items with duplicate sha256 from the session.
Uses a single query to identify all duplicates rather than querying for each item.
"""
# Find all SourceItem objects being added
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
if not new_items:
return
items = {}
for item in new_items:
try:
if (sha256 := item.sha256) is None:
continue
if sha256 in items:
session.expunge(item)
continue
items[sha256] = item
except (AttributeError, TypeError):
continue
if not new_items:
return
# Query database for existing items with these sha256 values in a single query
existing_sha256s = set(
row[0]
for row in session.query(SourceItem.sha256).filter(
SourceItem.sha256.in_(items.keys())
)
)
# Remove objects with duplicate sha256 values from the session
for sha256 in existing_sha256s:
if sha256 in items:
session.expunge(items[sha256])
def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images):
if not image.filename: # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename)
image.filename = str(filename) # type: ignore
return [image.filename for image in images] # type: ignore
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
return [chunk] + [
i
for i in images
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
data = m.copy()
if tags := set(data.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= data
return final
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip():
return []
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
summary, tags = summarizer.summarize(content)
full_text: extract.DataChunk = extract.DataChunk(
data=[content.strip(), *images], metadata={"tags": tags}
)
chunks: list[extract.DataChunk] = [full_text]
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
for c in chunker.chunk_text(content)
]
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
return [c for c in chunks if c.data]
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
__tablename__ = "chunk"
# The ID is also used as the vector ID in the vector database
id = Column(
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
)
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_paths = Column(
ARRAY(Text), nullable=True
) # Path to content if stored as a file
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
collection_name = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
# One of file_path or content must be populated
__table_args__ = (
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def chunks(self) -> list[extract.MulitmodalChunk]:
chunks: list[extract.MulitmodalChunk] = []
if cast(str | None, self.content):
chunks = [cast(str, self.content)]
if self.images:
chunks += self.images
elif cast(Sequence[str] | None, self.file_paths):
chunks += [
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
]
return chunks
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_paths is None:
return [cast(str, self.content)]
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
files = [path for path in paths if path.exists()]
items = []
for file_path in files:
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
items.append(file_path.read_bytes())
else:
items.append(file_path.read_text())
return items
class SourceItem(Base):
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
__allow_unmapped__ = True
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(
DateTime(timezone=True), server_default=func.now(), default=func.now()
)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
# Content is stored in the database if it's small enough and text
content = Column(Text)
# Otherwise the content is stored on disk
filename = Column(Text, nullable=True)
# Chunks relationship
embed_status = Column(Text, nullable=False, server_default="RAW")
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
# Discriminator column for SQLAlchemy inheritance
type = Column(String(50))
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index("source_modality_idx", "modality"),
Index("source_status_idx", "embed_status"),
Index("source_tags_idx", "tags", postgresql_using="gin"),
Index("source_filename_idx", "filename"),
)
@property
def vector_ids(self):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content)
if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"):
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
return chunks
def _make_chunk(
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
) -> Chunk:
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
modality = data.modality or cast(str, self.modality)
chunk = Chunk(
id=chunk_id,
source=self,
content=text or None,
images=images,
file_paths=image_names,
collection_name=modality,
embedding_model=collections.collection_model(modality, text, images),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data) for data in self._chunk_contents()]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
"size": self.size,
}
@property
def display_contents(self) -> str | dict | None:
return {
"tags": self.tags,
"size": self.size,
}

View File

@ -3,18 +3,14 @@ Database models for the knowledge base system.
"""
import pathlib
import re
import textwrap
from datetime import datetime
from typing import Any, Sequence, cast
import uuid
from PIL import Image
from sqlalchemy import (
ARRAY,
UUID,
BigInteger,
Boolean,
CheckConstraint,
Column,
DateTime,
@ -22,273 +18,24 @@ from sqlalchemy import (
Index,
Integer,
Numeric,
String,
Text,
event,
func,
)
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID
from sqlalchemy.orm import relationship
from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
import memory.common.formatters.observation as observation
Base = declarative_base()
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
"""
Event listener that efficiently checks for duplicate sha256 values before flush
and removes items with duplicate sha256 from the session.
Uses a single query to identify all duplicates rather than querying for each item.
"""
# Find all SourceItem objects being added
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
if not new_items:
return
items = {}
for item in new_items:
try:
if (sha256 := item.sha256) is None:
continue
if sha256 in items:
session.expunge(item)
continue
items[sha256] = item
except (AttributeError, TypeError):
continue
if not new_items:
return
# Query database for existing items with these sha256 values in a single query
existing_sha256s = set(
row[0]
for row in session.query(SourceItem.sha256).filter(
SourceItem.sha256.in_(items.keys())
)
)
# Remove objects with duplicate sha256 values from the session
for sha256 in existing_sha256s:
if sha256 in items:
session.expunge(items[sha256])
def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images):
if not image.filename: # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename)
image.filename = str(filename) # type: ignore
return [image.filename for image in images] # type: ignore
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
return [chunk] + [
i
for i in images
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
]
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
final = {}
for m in metadata:
if tags := set(m.pop("tags", [])):
final["tags"] = tags | final.get("tags", set())
final |= m
return final
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
if not content.strip():
return []
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
summary, tags = summarizer.summarize(content)
full_text: extract.DataChunk = extract.DataChunk(
data=[content.strip(), *images], metadata={"tags": tags}
)
chunks: list[extract.DataChunk] = [full_text]
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
for c in chunker.chunk_text(content)
]
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
return [c for c in chunks if c.data]
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
__tablename__ = "chunk"
# The ID is also used as the vector ID in the vector database
id = Column(
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
)
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_paths = Column(
ARRAY(Text), nullable=True
) # Path to content if stored as a file
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: list[float] = []
item_metadata: dict[str, Any] = {}
images: list[Image.Image] = []
# One of file_path or content must be populated
__table_args__ = (
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def chunks(self) -> list[extract.MulitmodalChunk]:
chunks: list[extract.MulitmodalChunk] = []
if cast(str | None, self.content):
chunks = [cast(str, self.content)]
if self.images:
chunks += self.images
elif cast(Sequence[str] | None, self.file_paths):
chunks += [
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
]
return chunks
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_paths is None:
return [cast(str, self.content)]
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
files = [path for path in paths if path.exists()]
items = []
for file_path in files:
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
if file_path.exists():
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
items.append(file_path.read_bytes())
else:
items.append(file_path.read_text())
return items
class SourceItem(Base):
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
__allow_unmapped__ = True
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
# Content is stored in the database if it's small enough and text
content = Column(Text)
# Otherwise the content is stored on disk
filename = Column(Text, nullable=True)
# Chunks relationship
embed_status = Column(Text, nullable=False, server_default="RAW")
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
# Discriminator column for SQLAlchemy inheritance
type = Column(String(50))
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index("source_modality_idx", "modality"),
Index("source_status_idx", "embed_status"),
Index("source_tags_idx", "tags", postgresql_using="gin"),
Index("source_filename_idx", "filename"),
)
@property
def vector_ids(self):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
chunks: list[extract.DataChunk] = []
content = cast(str | None, self.content)
if content:
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
summary, tags = summarizer.summarize(content)
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
mime_type = cast(str | None, self.mime_type)
if mime_type and mime_type.startswith("image/"):
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
return chunks
def _make_chunk(
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
) -> Chunk:
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
images = [c for c in data.data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
id=chunk_id,
source=self,
content=text or None,
images=images,
file_paths=image_names,
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
)
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data) for data in self._chunk_contents()]
def as_payload(self) -> dict:
return {
"source_id": self.id,
"tags": self.tags,
"size": self.size,
}
@property
def display_contents(self) -> str | None:
return cast(str | None, self.content) or cast(str | None, self.filename)
from memory.common.db.models.source_item import (
SourceItem,
Chunk,
clean_filename,
merge_metadata,
chunk_mixed,
)
class MailMessage(SourceItem):
@ -528,51 +275,6 @@ class Comic(SourceItem):
return [extract.DataChunk(data=[image, description])]
class Book(Base):
"""Book-level metadata table"""
__tablename__ = "book"
id = Column(BigInteger, primary_key=True)
isbn = Column(Text, unique=True)
title = Column(Text, nullable=False)
author = Column(Text)
publisher = Column(Text)
published = Column(DateTime(timezone=True))
language = Column(Text)
edition = Column(Text)
series = Column(Text)
series_number = Column(Integer)
total_pages = Column(Integer)
file_path = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
# Metadata from ebook parser
book_metadata = Column(JSONB, name="metadata")
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
Index("book_isbn_idx", "isbn"),
Index("book_author_idx", "author"),
Index("book_title_idx", "title"),
)
def as_payload(self) -> dict:
return {
**super().as_payload(),
"isbn": self.isbn,
"title": self.title,
"author": self.author,
"publisher": self.publisher,
"published": self.published,
"language": self.language,
"edition": self.edition,
"series": self.series,
"series_number": self.series_number,
} | (cast(dict, self.book_metadata) or {})
class BookSection(SourceItem):
"""Individual sections/chapters of books"""
@ -797,58 +499,163 @@ class GithubItem(SourceItem):
)
class ArticleFeed(Base):
__tablename__ = "article_feeds"
class AgentObservation(SourceItem):
"""
Records observations made by AI agents about the user.
This is the primary data model for the epistemic sparring partner.
"""
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="60", doc="Minutes between checks"
__tablename__ = "agent_observation"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
session_id = Column(
UUID(as_uuid=True)
) # Groups observations from same conversation
observation_type = Column(
Text, nullable=False
) # belief, preference, pattern, contradiction, behavior
subject = Column(Text, nullable=False) # What/who the observation is about
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
evidence = Column(JSONB) # Supporting context, quotes, etc.
agent_model = Column(Text, nullable=False) # Which AI model made this observation
# Relationships
contradictions_as_first = relationship(
"ObservationContradiction",
foreign_keys="ObservationContradiction.observation_1_id",
back_populates="observation_1",
cascade="all, delete-orphan",
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
contradictions_as_second = relationship(
"ObservationContradiction",
foreign_keys="ObservationContradiction.observation_2_id",
back_populates="observation_2",
cascade="all, delete-orphan",
)
# Add indexes
__mapper_args__ = {
"polymorphic_identity": "agent_observation",
}
__table_args__ = (
Index("article_feeds_active_idx", "active", "last_checked_at"),
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
Index("agent_obs_session_idx", "session_id"),
Index("agent_obs_type_idx", "observation_type"),
Index("agent_obs_subject_idx", "subject"),
Index("agent_obs_confidence_idx", "confidence"),
Index("agent_obs_model_idx", "agent_model"),
)
def __init__(self, **kwargs):
if not kwargs.get("modality"):
kwargs["modality"] = "observation"
super().__init__(**kwargs)
class EmailAccount(Base):
__tablename__ = "email_accounts"
def as_payload(self) -> dict:
payload = {
**super().as_payload(),
"observation_type": self.observation_type,
"subject": self.subject,
"confidence": float(cast(Any, self.confidence)),
"evidence": self.evidence,
"agent_model": self.agent_model,
}
if self.session_id is not None:
payload["session_id"] = str(self.session_id)
return payload
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default="993")
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default="true")
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
@property
def all_contradictions(self):
"""Get all contradictions involving this observation."""
return self.contradictions_as_first + self.contradictions_as_second
# Add indexes
__table_args__ = (
Index("email_accounts_address_idx", "email_address", unique=True),
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
@property
def display_contents(self) -> dict:
return {
"subject": self.subject,
"content": self.content,
"observation_type": self.observation_type,
"evidence": self.evidence,
"confidence": self.confidence,
"agent_model": self.agent_model,
"tags": self.tags,
}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
"""
Generate multiple chunks for different embedding dimensions.
Each chunk goes to a different Qdrant collection for specialized search.
"""
# 1. Semantic chunk - standard content representation
semantic_text = observation.generate_semantic_text(
cast(str, self.subject),
cast(str, self.observation_type),
cast(str, self.content),
cast(observation.Evidence, self.evidence),
)
semantic_chunk = extract.DataChunk(
data=[semantic_text],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
modality="semantic",
)
# 2. Temporal chunk - time-aware representation
temporal_text = observation.generate_temporal_text(
cast(str, self.subject),
cast(str, self.content),
cast(float, self.confidence),
cast(datetime, self.inserted_at),
)
temporal_chunk = extract.DataChunk(
data=[temporal_text],
metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
modality="temporal",
)
others = [
self._make_chunk(
extract.DataChunk(
data=[i],
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
modality="semantic",
)
)
for i in [
self.content,
self.subject,
self.observation_type,
self.evidence.get("quote", ""),
]
if i
]
# TODO: Add more embedding dimensions here:
# 3. Epistemic chunk - belief structure focused
# epistemic_text = self._generate_epistemic_text()
# chunks.append(extract.DataChunk(
# data=[epistemic_text],
# metadata={**base_metadata, "embedding_type": "epistemic"},
# collection_name="observations_epistemic"
# ))
#
# 4. Emotional chunk - emotional context focused
# emotional_text = self._generate_emotional_text()
# chunks.append(extract.DataChunk(
# data=[emotional_text],
# metadata={**base_metadata, "embedding_type": "emotional"},
# collection_name="observations_emotional"
# ))
#
# 5. Relational chunk - connection patterns focused
# relational_text = self._generate_relational_text()
# chunks.append(extract.DataChunk(
# data=[relational_text],
# metadata={**base_metadata, "embedding_type": "relational"},
# collection_name="observations_relational"
# ))
return [
self._make_chunk(semantic_chunk),
self._make_chunk(temporal_chunk),
] + others

View File

@ -0,0 +1,122 @@
"""
Database models for the knowledge base system.
"""
from typing import cast
from sqlalchemy import (
ARRAY,
BigInteger,
Boolean,
Column,
DateTime,
Index,
Integer,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB
from memory.common.db.models.base import Base
class Book(Base):
"""Book-level metadata table"""
__tablename__ = "book"
id = Column(BigInteger, primary_key=True)
isbn = Column(Text, unique=True)
title = Column(Text, nullable=False)
author = Column(Text)
publisher = Column(Text)
published = Column(DateTime(timezone=True))
language = Column(Text)
edition = Column(Text)
series = Column(Text)
series_number = Column(Integer)
total_pages = Column(Integer)
file_path = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
# Metadata from ebook parser
book_metadata = Column(JSONB, name="metadata")
created_at = Column(DateTime(timezone=True), server_default=func.now())
__table_args__ = (
Index("book_isbn_idx", "isbn"),
Index("book_author_idx", "author"),
Index("book_title_idx", "title"),
)
def as_payload(self) -> dict:
return {
**super().as_payload(),
"isbn": self.isbn,
"title": self.title,
"author": self.author,
"publisher": self.publisher,
"published": self.published,
"language": self.language,
"edition": self.edition,
"series": self.series,
"series_number": self.series_number,
} | (cast(dict, self.book_metadata) or {})
class ArticleFeed(Base):
__tablename__ = "article_feeds"
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
check_interval = Column(
Integer, nullable=False, server_default="60", doc="Minutes between checks"
)
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("article_feeds_active_idx", "active", "last_checked_at"),
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
)
class EmailAccount(Base):
__tablename__ = "email_accounts"
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default="993")
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default="true")
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("email_accounts_address_idx", "email_address", unique=True),
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)

View File

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
def embed_chunks(
chunks: list[str] | list[list[extract.MulitmodalChunk]],
chunks: list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
) -> list[Vector]:
@ -24,52 +24,50 @@ def embed_chunks(
vo = voyageai.Client() # type: ignore
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
chunks, # type: ignore
chunks,
model=model,
input_type=input_type,
).embeddings
return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore
texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
return cast(
list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
)
def break_chunk(
chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
) -> list[extract.MulitmodalChunk]:
result = []
for c in chunk:
if isinstance(c, str):
result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
else:
result.append(chunk)
return result
def embed_text(
texts: list[str],
chunks: list[list[extract.MulitmodalChunk]],
model: str = settings.TEXT_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]:
chunks = [
c
for text in texts
if isinstance(text, str)
for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
if c.strip()
]
if not chunks:
chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
if not any(chunked_chunks):
return []
try:
return embed_chunks(chunks, model, input_type)
except voyageai.error.InvalidRequestError as e: # type: ignore
logger.error(f"Error embedding text: {e}")
logger.debug(f"Text: {texts}")
raise
return embed_chunks(chunked_chunks, model, input_type)
def embed_mixed(
items: list[extract.MulitmodalChunk],
items: list[list[extract.MulitmodalChunk]],
model: str = settings.MIXED_EMBEDDING_MODEL,
input_type: Literal["document", "query"] = "document",
chunk_size: int = DEFAULT_CHUNK_TOKENS,
) -> list[Vector]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
if isinstance(item, str):
return [
c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
]
return [item]
chunks = [c for item in items for c in to_chunks(item)]
return embed_chunks([chunks], model, input_type)
chunked_chunks = [break_chunk(item, chunk_size) for item in items]
return embed_chunks(chunked_chunks, model, input_type)
def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
def embed_source_item(item: SourceItem) -> list[Chunk]:
chunks = list(item.data_chunks())
logger.error(
f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
)
if not chunks:
return []

View File

@ -21,6 +21,7 @@ class DataChunk:
data: Sequence[MulitmodalChunk]
metadata: dict[str, Any] = field(default_factory=dict)
mime_type: str = "text/plain"
modality: str | None = None
@contextmanager

View File

@ -0,0 +1,87 @@
from datetime import datetime
from typing import TypedDict
class Evidence(TypedDict):
quote: str
context: str
def generate_semantic_text(
subject: str, observation_type: str, content: str, evidence: Evidence | None = None
) -> str:
"""Generate text optimized for semantic similarity search."""
parts = [
f"Subject: {subject}",
f"Type: {observation_type}",
f"Observation: {content}",
]
if not evidence or not isinstance(evidence, dict):
return " | ".join(parts)
if "quote" in evidence:
parts.append(f"Quote: {evidence['quote']}")
if "context" in evidence:
parts.append(f"Context: {evidence['context']}")
return " | ".join(parts)
def generate_temporal_text(
subject: str,
content: str,
confidence: float,
created_at: datetime,
) -> str:
"""Generate text with temporal context for time-pattern search."""
# Add temporal markers
time_of_day = created_at.strftime("%H:%M")
day_of_week = created_at.strftime("%A")
# Categorize time periods
hour = created_at.hour
if 5 <= hour < 12:
time_period = "morning"
elif 12 <= hour < 17:
time_period = "afternoon"
elif 17 <= hour < 22:
time_period = "evening"
else:
time_period = "late_night"
parts = [
f"Time: {time_of_day} on {day_of_week} ({time_period})",
f"Subject: {subject}",
f"Observation: {content}",
]
if confidence:
parts.append(f"Confidence: {confidence}")
return " | ".join(parts)
# TODO: Add more embedding dimensions here:
# 3. Epistemic chunk - belief structure focused
# epistemic_text = self._generate_epistemic_text()
# chunks.append(extract.DataChunk(
# data=[epistemic_text],
# metadata={**base_metadata, "embedding_type": "epistemic"},
# collection_name="observations_epistemic"
# ))
#
# 4. Emotional chunk - emotional context focused
# emotional_text = self._generate_emotional_text()
# chunks.append(extract.DataChunk(
# data=[emotional_text],
# metadata={**base_metadata, "embedding_type": "emotional"},
# collection_name="observations_emotional"
# ))
#
# 5. Relational chunk - connection patterns focused
# relational_text = self._generate_relational_text()
# chunks.append(extract.DataChunk(
# data=[relational_text],
# metadata={**base_metadata, "embedding_type": "relational"},
# collection_name="observations_relational"
# ))

View File

@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
TAGS_PROMPT = """
The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
Return your response as JSON with this format:
{{
"summary": "{summary}",
@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
Also provide 3-5 relevant tags that capture the main topics or themes.
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
Return your response as JSON with this format:
{{
"summary": "your summary here",

View File

@ -1,62 +0,0 @@
import argparse
import logging
from typing import Any
from fastapi import UploadFile
import httpx
from mcp.server.fastmcp import FastMCP
SERVER = "http://localhost:8000"
logger = logging.getLogger(__name__)
mcp = FastMCP("memory")
async def make_request(
path: str,
method: str,
data: dict | None = None,
json: dict | None = None,
files: list[UploadFile] | None = None,
) -> httpx.Response:
async with httpx.AsyncClient() as client:
return await client.request(
method,
f"{SERVER}/{path}",
data=data,
json=json,
files=files, # type: ignore
)
async def post_data(path: str, data: dict | None = None) -> httpx.Response:
return await make_request(path, "POST", data=data)
@mcp.tool()
async def search(
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict[str, Any]]:
logger.error(f"Searching for {query}")
resp = await post_data(
"search",
{
"query": query,
"previews": previews,
"modalities": modalities,
"limit": limit,
},
)
return resp.json()
if __name__ == "__main__":
# Initialize and run the server
args = argparse.ArgumentParser()
args.add_argument("--server", type=str)
args = args.parse_args()
SERVER = args.server
mcp.run(transport=args.transport)

View File

@ -6,13 +6,14 @@ the complete workflow: existence checking, content hashing, embedding generation
vector storage, and result tracking.
"""
from collections import defaultdict
import hashlib
import traceback
import logging
from typing import Any, Callable, Iterable, Sequence, cast
from memory.common import embedding, qdrant
from memory.common.db.models import SourceItem
from memory.common.db.models import SourceItem, Chunk
logger = logging.getLogger(__name__)
@ -103,7 +104,19 @@ def embed_source_item(source_item: SourceItem) -> int:
return 0
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]:
collections: dict[str, dict[str, list[Any]]] = defaultdict(
lambda: defaultdict(list)
)
for chunk in chunks:
collection = collections[cast(str, chunk.collection_name)]
collection["ids"].append(chunk.id)
collection["vectors"].append(chunk.vector)
collection["payloads"].append(chunk.item_metadata)
return collections
def push_to_qdrant(source_items: Sequence[SourceItem]):
"""
Push embeddings to Qdrant vector database.
@ -135,17 +148,16 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
return
try:
vector_ids = [str(chunk.id) for chunk in all_chunks]
vectors = [chunk.vector for chunk in all_chunks]
payloads = [chunk.item_metadata for chunk in all_chunks]
qdrant.upsert_vectors(
client=qdrant.get_qdrant_client(),
collection_name=collection_name,
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
client = qdrant.get_qdrant_client()
collections = by_collection(all_chunks)
for collection_name, collection in collections.items():
qdrant.upsert_vectors(
client=client,
collection_name=collection_name,
ids=collection["ids"],
vectors=collection["vectors"],
payloads=collection["payloads"],
)
for item in items_to_process:
item.embed_status = "STORED" # type: ignore
@ -222,7 +234,7 @@ def process_content_item(item: SourceItem, session) -> dict[str, Any]:
return create_task_result(item, status, content_length=getattr(item, "size", 0))
try:
push_to_qdrant([item], cast(str, item.modality))
push_to_qdrant([item])
status = "processed"
item.embed_status = "STORED" # type: ignore
logger.info(

View File

@ -185,7 +185,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
f"Embedded section: {section.section_title} - {section.content[:100]}"
)
logger.info("Pushing to Qdrant")
push_to_qdrant(all_sections, "book")
push_to_qdrant(all_sections)
logger.info("Committing session")
session.commit()

View File

@ -1,4 +1,3 @@
from memory.common.db.models import SourceItem
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from typing import cast
@ -6,17 +5,20 @@ import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings, chunker, extract
from memory.common.db.models import (
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import (
Chunk,
clean_filename,
image_filenames,
add_pics,
MailMessage,
EmailAttachment,
BookSection,
BlogPost,
Book,
)
from memory.common.db.models.source_item import (
SourceItem,
image_filenames,
add_pics,
merge_metadata,
clean_filename,
)
@ -585,288 +587,6 @@ def test_chunk_constraint_validation(
assert chunk.id is not None
@pytest.mark.parametrize(
"modality,expected_modality",
[
(None, "email"), # Default case
("custom", "custom"), # Override case
],
)
def test_mail_message_modality(modality, expected_modality):
"""Test MailMessage modality setting"""
kwargs = {"sha256": b"test", "content": "test"}
if modality is not None:
kwargs["modality"] = modality
mail_message = MailMessage(**kwargs)
# The __init__ method should set the correct modality
assert hasattr(mail_message, "modality")
@pytest.mark.parametrize(
"sender,folder,expected_path",
[
("user@example.com", "INBOX", "user_example_com/INBOX"),
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
],
)
def test_mail_message_attachments_path(sender, folder, expected_path):
"""Test MailMessage.attachments_path property"""
mail_message = MailMessage(
sha256=b"test", content="test", sender=sender, folder=folder
)
result = mail_message.attachments_path
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
@pytest.mark.parametrize(
"filename,expected",
[
("document.pdf", "document.pdf"),
("file with spaces.txt", "file_with_spaces.txt"),
("file@#$%^&*().doc", "file.doc"),
("no-extension", "no_extension"),
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
],
)
def test_mail_message_safe_filename(tmp_path, filename, expected):
"""Test MailMessage.safe_filename method"""
mail_message = MailMessage(
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
)
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
assert mail_message.safe_filename(filename) == expected
@pytest.mark.parametrize(
"sent_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_mail_message_as_payload(sent_at, expected_date):
"""Test MailMessage.as_payload method"""
mail_message = MailMessage(
sha256=b"test",
content="test",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient1@example.com", "recipient2@example.com"],
folder="INBOX",
sent_at=sent_at,
tags=["tag1", "tag2"],
size=1024,
)
# Manually set id for testing
object.__setattr__(mail_message, "id", 123)
payload = mail_message.as_payload()
expected = {
"source_id": 123,
"size": 1024,
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient1@example.com", "recipient2@example.com"],
"folder": "INBOX",
"tags": [
"tag1",
"tag2",
"sender@example.com",
"recipient1@example.com",
"recipient2@example.com",
],
"date": expected_date,
}
assert payload == expected
def test_mail_message_parsed_content():
"""Test MailMessage.parsed_content property with actual email parsing"""
# Use a simple email format that the parser can handle
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
result = mail_message.parsed_content
# Just test that it returns a dict-like object
assert isinstance(result, dict)
assert "body" in result
def test_mail_message_body_property():
"""Test MailMessage.body property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
assert mail_message.body == "Test Body Content"
def test_mail_message_display_contents():
"""Test MailMessage.display_contents property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
expected = (
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
)
assert mail_message.display_contents == expected
@pytest.mark.parametrize(
"created_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_email_attachment_as_payload(created_at, expected_date):
"""Test EmailAttachment.as_payload method"""
attachment = EmailAttachment(
sha256=b"test",
filename="document.pdf",
mime_type="application/pdf",
size=1024,
mail_message_id=123,
created_at=created_at,
tags=["pdf", "document"],
)
# Manually set id for testing
object.__setattr__(attachment, "id", 456)
payload = attachment.as_payload()
expected = {
"source_id": 456,
"filename": "document.pdf",
"content_type": "application/pdf",
"size": 1024,
"created_at": expected_date,
"mail_message_id": 123,
"tags": ["pdf", "document"],
}
assert payload == expected
@pytest.mark.parametrize(
"has_filename,content_source,expected_content",
[
(True, "file", b"test file content"),
(False, "content", "attachment content"),
],
)
@patch("memory.common.extract.extract_data_chunks")
def test_email_attachment_data_chunks(
mock_extract, has_filename, content_source, expected_content, tmp_path
):
"""Test EmailAttachment.data_chunks method"""
from memory.common.extract import DataChunk
mock_extract.return_value = [
DataChunk(data=["extracted text"], metadata={"source": content_source})
]
if has_filename:
# Create a test file
test_file = tmp_path / "test.txt"
test_file.write_bytes(b"test file content")
attachment = EmailAttachment(
sha256=b"test",
filename=str(test_file),
mime_type="text/plain",
mail_message_id=123,
)
else:
attachment = EmailAttachment(
sha256=b"test",
content="attachment content",
filename=None,
mime_type="text/plain",
mail_message_id=123,
)
# Mock _make_chunk to return a simple chunk
mock_chunk = Mock()
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
result = attachment.data_chunks({"extra": "metadata"})
# Verify the method calls
mock_extract.assert_called_once_with("text/plain", expected_content)
mock_make.assert_called_once_with(
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
{"extra": "metadata"},
)
assert result == [mock_chunk]
def test_email_attachment_cascade_delete(db_session: Session):
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
mail_message = MailMessage(
sha256=b"test_email",
content="test email",
message_id="<test@example.com>",
subject="Test",
sender="sender@example.com",
recipients=["recipient@example.com"],
folder="INBOX",
)
db_session.add(mail_message)
db_session.commit()
attachment = EmailAttachment(
sha256=b"test_attachment",
content="attachment content",
mail_message=mail_message,
filename="test.txt",
mime_type="text/plain",
size=100,
modality="attachment", # Set modality explicitly
)
db_session.add(attachment)
db_session.commit()
attachment_id = attachment.id
# Delete the mail message
db_session.delete(mail_message)
db_session.commit()
# Verify the attachment was also deleted
deleted_attachment = (
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
)
assert deleted_attachment is None
def test_subclass_deletion_cascades_to_source_item(db_session: Session):
mail_message = MailMessage(
sha256=b"test_email_cascade",
@ -926,173 +646,3 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
# Verify both the MailMessage and SourceItem records are deleted
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
@pytest.mark.parametrize(
"pages,expected_chunks",
[
# No pages
([], []),
# Single page
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
# Multiple pages
(
["Page 1", "Page 2", "Page 3"],
[
(
"Page 1\n\nPage 2\n\nPage 3",
{"type": "section", "tags": {"tag1", "tag2"}},
),
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
],
),
# Empty/whitespace pages filtered out
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
# All empty - no chunks created
(["", " ", " "], []),
],
)
def test_book_section_data_chunks(pages, expected_chunks):
"""Test BookSection.data_chunks with various page combinations"""
content = "\n\n".join(pages).strip()
book_section = BookSection(
sha256=b"test_section",
content=content,
modality="book",
book_id=1,
start_page=10,
end_page=10 + len(pages),
pages=pages,
book=Book(id=1, title="Test Book", author="Test Author"),
)
chunks = book_section.data_chunks()
expected = [
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
]
assert [(c.content, c.item_metadata) for c in chunks] == expected
for c in chunks:
assert cast(list, c.file_paths) == []
@pytest.mark.parametrize(
"content,expected",
[
("", []),
(
"Short content",
[
extract.DataChunk(
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
)
],
),
(
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
[
extract.DataChunk(
data=[
"This is a very long piece of content that should be chunked into multiple pieces when processed."
],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["This is a very long piece of content that"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["should be chunked into multiple pieces when"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["processed."],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["test"],
metadata={"tags": ["tag1", "tag2"]},
),
],
),
],
)
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
default_chunk_size(10)
blog_post = BlogPost(
sha256=b"test_blog",
content=content,
modality="blog",
url="https://example.com/post",
images=[],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
assert blog_post._chunk_contents() == expected
def test_blog_post_chunk_contents_with_images(tmp_path):
"""Test BlogPost._chunk_contents with images"""
# Create test image files
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content="Content with images",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
assert result == [
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
]
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
default_chunk_size(10)
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
assert result == [
[
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
img1_path.as_posix(),
img2_path.as_posix(),
],
[
f"First picture is here: {img1_path.as_posix()}",
img1_path.as_posix(),
],
[
f"Second picture is here: {img2_path.as_posix()}",
img2_path.as_posix(),
],
["test"],
]

View File

@ -0,0 +1,607 @@
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from typing import cast
import pytest
from PIL import Image
from datetime import datetime
import uuid
from memory.common import settings, chunker, extract
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import (
MailMessage,
EmailAttachment,
BookSection,
BlogPost,
AgentObservation,
)
from memory.common.db.models.source_item import merge_metadata
@pytest.fixture
def default_chunk_size():
chunk_length = chunker.DEFAULT_CHUNK_TOKENS
real_chunker = chunker.chunk_text
def chunk_text(text: str, max_tokens: int = 0):
max_tokens = max_tokens or chunk_length
return real_chunker(text, max_tokens=max_tokens)
def set_size(new_size: int):
nonlocal chunk_length
chunk_length = new_size
with patch.object(chunker, "chunk_text", chunk_text):
yield set_size
@pytest.mark.parametrize(
"modality,expected_modality",
[
(None, "email"), # Default case
("custom", "custom"), # Override case
],
)
def test_mail_message_modality(modality, expected_modality):
"""Test MailMessage modality setting"""
kwargs = {"sha256": b"test", "content": "test"}
if modality is not None:
kwargs["modality"] = modality
mail_message = MailMessage(**kwargs)
# The __init__ method should set the correct modality
assert hasattr(mail_message, "modality")
@pytest.mark.parametrize(
"sender,folder,expected_path",
[
("user@example.com", "INBOX", "user_example_com/INBOX"),
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
],
)
def test_mail_message_attachments_path(sender, folder, expected_path):
"""Test MailMessage.attachments_path property"""
mail_message = MailMessage(
sha256=b"test", content="test", sender=sender, folder=folder
)
result = mail_message.attachments_path
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
@pytest.mark.parametrize(
"filename,expected",
[
("document.pdf", "document.pdf"),
("file with spaces.txt", "file_with_spaces.txt"),
("file@#$%^&*().doc", "file.doc"),
("no-extension", "no_extension"),
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
],
)
def test_mail_message_safe_filename(tmp_path, filename, expected):
"""Test MailMessage.safe_filename method"""
mail_message = MailMessage(
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
)
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
assert mail_message.safe_filename(filename) == expected
@pytest.mark.parametrize(
"sent_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_mail_message_as_payload(sent_at, expected_date):
"""Test MailMessage.as_payload method"""
mail_message = MailMessage(
sha256=b"test",
content="test",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient1@example.com", "recipient2@example.com"],
folder="INBOX",
sent_at=sent_at,
tags=["tag1", "tag2"],
size=1024,
)
# Manually set id for testing
object.__setattr__(mail_message, "id", 123)
payload = mail_message.as_payload()
expected = {
"source_id": 123,
"size": 1024,
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient1@example.com", "recipient2@example.com"],
"folder": "INBOX",
"tags": [
"tag1",
"tag2",
"sender@example.com",
"recipient1@example.com",
"recipient2@example.com",
],
"date": expected_date,
}
assert payload == expected
def test_mail_message_parsed_content():
"""Test MailMessage.parsed_content property with actual email parsing"""
# Use a simple email format that the parser can handle
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
result = mail_message.parsed_content
# Just test that it returns a dict-like object
assert isinstance(result, dict)
assert "body" in result
def test_mail_message_body_property():
"""Test MailMessage.body property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
assert mail_message.body == "Test Body Content"
def test_mail_message_display_contents():
"""Test MailMessage.display_contents property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
expected = (
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
)
assert mail_message.display_contents == expected
@pytest.mark.parametrize(
"created_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_email_attachment_as_payload(created_at, expected_date):
"""Test EmailAttachment.as_payload method"""
attachment = EmailAttachment(
sha256=b"test",
filename="document.pdf",
mime_type="application/pdf",
size=1024,
mail_message_id=123,
created_at=created_at,
tags=["pdf", "document"],
)
# Manually set id for testing
object.__setattr__(attachment, "id", 456)
payload = attachment.as_payload()
expected = {
"source_id": 456,
"filename": "document.pdf",
"content_type": "application/pdf",
"size": 1024,
"created_at": expected_date,
"mail_message_id": 123,
"tags": ["pdf", "document"],
}
assert payload == expected
@pytest.mark.parametrize(
"has_filename,content_source,expected_content",
[
(True, "file", b"test file content"),
(False, "content", "attachment content"),
],
)
@patch("memory.common.extract.extract_data_chunks")
def test_email_attachment_data_chunks(
mock_extract, has_filename, content_source, expected_content, tmp_path
):
"""Test EmailAttachment.data_chunks method"""
from memory.common.extract import DataChunk
mock_extract.return_value = [
DataChunk(data=["extracted text"], metadata={"source": content_source})
]
if has_filename:
# Create a test file
test_file = tmp_path / "test.txt"
test_file.write_bytes(b"test file content")
attachment = EmailAttachment(
sha256=b"test",
filename=str(test_file),
mime_type="text/plain",
mail_message_id=123,
)
else:
attachment = EmailAttachment(
sha256=b"test",
content="attachment content",
filename=None,
mime_type="text/plain",
mail_message_id=123,
)
# Mock _make_chunk to return a simple chunk
mock_chunk = Mock()
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
result = attachment.data_chunks({"extra": "metadata"})
# Verify the method calls
mock_extract.assert_called_once_with("text/plain", expected_content)
mock_make.assert_called_once_with(
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
{"extra": "metadata"},
)
assert result == [mock_chunk]
def test_email_attachment_cascade_delete(db_session: Session):
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
mail_message = MailMessage(
sha256=b"test_email",
content="test email",
message_id="<test@example.com>",
subject="Test",
sender="sender@example.com",
recipients=["recipient@example.com"],
folder="INBOX",
)
db_session.add(mail_message)
db_session.commit()
attachment = EmailAttachment(
sha256=b"test_attachment",
content="attachment content",
mail_message=mail_message,
filename="test.txt",
mime_type="text/plain",
size=100,
modality="attachment", # Set modality explicitly
)
db_session.add(attachment)
db_session.commit()
attachment_id = attachment.id
# Delete the mail message
db_session.delete(mail_message)
db_session.commit()
# Verify the attachment was also deleted
deleted_attachment = (
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
)
assert deleted_attachment is None
@pytest.mark.parametrize(
"pages,expected_chunks",
[
# No pages
([], []),
# Single page
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
# Multiple pages
(
["Page 1", "Page 2", "Page 3"],
[
(
"Page 1\n\nPage 2\n\nPage 3",
{"type": "section", "tags": {"tag1", "tag2"}},
),
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
],
),
# Empty/whitespace pages filtered out
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
# All empty - no chunks created
(["", " ", " "], []),
],
)
def test_book_section_data_chunks(pages, expected_chunks):
"""Test BookSection.data_chunks with various page combinations"""
content = "\n\n".join(pages).strip()
book_section = BookSection(
sha256=b"test_section",
content=content,
modality="book",
book_id=1,
start_page=10,
end_page=10 + len(pages),
pages=pages,
book=Book(id=1, title="Test Book", author="Test Author"),
)
chunks = book_section.data_chunks()
expected = [
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
]
assert [(c.content, c.item_metadata) for c in chunks] == expected
for c in chunks:
assert cast(list, c.file_paths) == []
@pytest.mark.parametrize(
"content,expected",
[
("", []),
(
"Short content",
[
extract.DataChunk(
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
)
],
),
(
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
[
extract.DataChunk(
data=[
"This is a very long piece of content that should be chunked into multiple pieces when processed."
],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["This is a very long piece of content that"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["should be chunked into multiple pieces when"],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["processed."],
metadata={"tags": ["tag1", "tag2"]},
),
extract.DataChunk(
data=["test"],
metadata={"tags": ["tag1", "tag2"]},
),
],
),
],
)
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
default_chunk_size(10)
blog_post = BlogPost(
sha256=b"test_blog",
content=content,
modality="blog",
url="https://example.com/post",
images=[],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
assert blog_post._chunk_contents() == expected
def test_blog_post_chunk_contents_with_images(tmp_path):
"""Test BlogPost._chunk_contents with images"""
# Create test image files
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content="Content with images",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
assert result == [
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
]
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
default_chunk_size(10)
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
for c in result
]
assert result == [
[
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
img1_path.as_posix(),
img2_path.as_posix(),
],
[
f"First picture is here: {img1_path.as_posix()}",
img1_path.as_posix(),
],
[
f"Second picture is here: {img2_path.as_posix()}",
img2_path.as_posix(),
],
["test"],
]
@pytest.mark.parametrize(
"metadata,expected_semantic_metadata,expected_temporal_metadata,observation_tags",
[
(
{},
{"embedding_type": "semantic"},
{"embedding_type": "temporal"},
[],
),
(
{"extra_key": "extra_value"},
{"extra_key": "extra_value", "embedding_type": "semantic"},
{"extra_key": "extra_value", "embedding_type": "temporal"},
[],
),
(
{"tags": ["existing_tag"], "source": "test"},
{"tags": {"existing_tag"}, "source": "test", "embedding_type": "semantic"},
{"tags": {"existing_tag"}, "source": "test", "embedding_type": "temporal"},
[],
),
],
)
def test_agent_observation_data_chunks(
metadata, expected_semantic_metadata, expected_temporal_metadata, observation_tags
):
"""Test AgentObservation.data_chunks generates correct chunks with proper metadata"""
observation = AgentObservation(
sha256=b"test_obs",
content="User prefers Python over JavaScript",
subject="programming preferences",
observation_type="preference",
confidence=0.9,
evidence={
"quote": "I really like Python",
"context": "discussion about languages",
},
agent_model="claude-3.5-sonnet",
session_id=uuid.uuid4(),
tags=observation_tags,
)
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))
result = observation.data_chunks(metadata)
# Verify chunks
assert len(result) == 2
semantic_chunk = result[0]
expected_semantic_text = "Subject: programming preferences | Type: preference | Observation: User prefers Python over JavaScript | Quote: I really like Python | Context: discussion about languages"
assert semantic_chunk.data == [expected_semantic_text]
assert semantic_chunk.metadata == expected_semantic_metadata
assert semantic_chunk.collection_name == "semantic"
temporal_chunk = result[1]
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
assert temporal_chunk.data == [expected_temporal_text]
assert temporal_chunk.metadata == expected_temporal_metadata
assert temporal_chunk.collection_name == "temporal"
def test_agent_observation_data_chunks_with_none_values():
"""Test AgentObservation.data_chunks handles None values correctly"""
observation = AgentObservation(
sha256=b"test_obs",
content="Content",
subject="subject",
observation_type="belief",
confidence=0.7,
evidence=None,
agent_model="gpt-4",
session_id=None,
)
object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))
result = observation.data_chunks()
assert len(result) == 2
assert result[0].collection_name == "semantic"
assert result[1].collection_name == "temporal"
# Verify content with None evidence
semantic_text = "Subject: subject | Type: belief | Observation: Content"
assert result[0].data == [semantic_text]
temporal_text = "Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
assert result[1].data == [temporal_text]
def test_agent_observation_data_chunks_merge_metadata_behavior():
"""Test that merge_metadata works correctly in data_chunks"""
observation = AgentObservation(
sha256=b"test",
content="test",
subject="test",
observation_type="test",
confidence=0.8,
evidence={},
agent_model="test",
tags=["base_tag"], # Set base tags so they appear in both chunks
)
object.__setattr__(observation, "inserted_at", datetime.now())
# Test that metadata merging preserves original values and adds new ones
input_metadata = {"existing": "value", "tags": ["tag1"]}
result = observation.data_chunks(input_metadata)
semantic_metadata = result[0].metadata
temporal_metadata = result[1].metadata
# Both should have the existing metadata plus embedding_type
assert semantic_metadata["existing"] == "value"
assert semantic_metadata["tags"] == {"tag1"} # Merged tags
assert semantic_metadata["embedding_type"] == "semantic"
assert temporal_metadata["existing"] == "value"
assert temporal_metadata["tags"] == {"tag1"} # Merged tags
assert temporal_metadata["embedding_type"] == "temporal"

View File

@ -0,0 +1,220 @@
import pytest
from datetime import datetime
from typing import Any
from memory.common.formatters.observation import (
Evidence,
generate_semantic_text,
generate_temporal_text,
)
def test_generate_semantic_text_basic_functionality():
evidence: Evidence = {"quote": "test quote", "context": "test context"}
result = generate_semantic_text(
subject="test_subject",
observation_type="test_type",
content="test_content",
evidence=evidence,
)
assert (
result
== "Subject: test_subject | Type: test_type | Observation: test_content | Quote: test quote | Context: test context"
)
@pytest.mark.parametrize(
"evidence,expected_suffix",
[
({"quote": "test quote"}, " | Quote: test quote"),
({"context": "test context"}, " | Context: test context"),
({}, ""),
],
)
def test_generate_semantic_text_partial_evidence(
evidence: dict[str, str], expected_suffix: str
):
result = generate_semantic_text(
subject="subject",
observation_type="type",
content="content",
evidence=evidence, # type: ignore
)
expected = f"Subject: subject | Type: type | Observation: content{expected_suffix}"
assert result == expected
def test_generate_semantic_text_none_evidence():
result = generate_semantic_text(
subject="subject",
observation_type="type",
content="content",
evidence=None, # type: ignore
)
assert result == "Subject: subject | Type: type | Observation: content"
@pytest.mark.parametrize(
"invalid_evidence",
[
"string",
123,
["list"],
True,
],
)
def test_generate_semantic_text_invalid_evidence_types(invalid_evidence: Any):
result = generate_semantic_text(
subject="subject",
observation_type="type",
content="content",
evidence=invalid_evidence, # type: ignore
)
assert result == "Subject: subject | Type: type | Observation: content"
def test_generate_semantic_text_empty_strings():
evidence = {"quote": "", "context": ""}
result = generate_semantic_text(
subject="",
observation_type="",
content="",
evidence=evidence, # type: ignore
)
assert result == "Subject: | Type: | Observation: | Quote: | Context: "
def test_generate_semantic_text_special_characters():
evidence: Evidence = {
"quote": "Quote with | pipe and | symbols",
"context": "Context with special chars: @#$%",
}
result = generate_semantic_text(
subject="Subject with | pipe",
observation_type="Type with | pipe",
content="Content with | pipe",
evidence=evidence,
)
expected = "Subject: Subject with | pipe | Type: Type with | pipe | Observation: Content with | pipe | Quote: Quote with | pipe and | symbols | Context: Context with special chars: @#$%"
assert result == expected
@pytest.mark.parametrize(
"hour,expected_period",
[
(5, "morning"),
(6, "morning"),
(11, "morning"),
(12, "afternoon"),
(13, "afternoon"),
(16, "afternoon"),
(17, "evening"),
(18, "evening"),
(21, "evening"),
(22, "late_night"),
(23, "late_night"),
(0, "late_night"),
(1, "late_night"),
(4, "late_night"),
],
)
def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
test_date = datetime(2024, 1, 15, hour, 30) # Monday
result = generate_temporal_text(
subject="test_subject",
content="test_content",
confidence=0.8,
created_at=test_date,
)
time_str = test_date.strftime("%H:%M")
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content | Confidence: 0.8"
assert result == expected
@pytest.mark.parametrize(
"weekday,day_name",
[
(0, "Monday"),
(1, "Tuesday"),
(2, "Wednesday"),
(3, "Thursday"),
(4, "Friday"),
(5, "Saturday"),
(6, "Sunday"),
],
)
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
test_date = datetime(2024, 1, 15 + weekday, 10, 30)
result = generate_temporal_text(
subject="subject", content="content", confidence=0.5, created_at=test_date
)
assert f"on {day_name}" in result
@pytest.mark.parametrize("confidence", [0.0, 0.1, 0.5, 0.99, 1.0])
def test_generate_temporal_text_confidence_values(confidence: float):
test_date = datetime(2024, 1, 15, 10, 30)
result = generate_temporal_text(
subject="subject",
content="content",
confidence=confidence,
created_at=test_date,
)
assert f"Confidence: {confidence}" in result
@pytest.mark.parametrize(
"test_date,expected_period",
[
(datetime(2024, 1, 15, 5, 0), "morning"), # Start of morning
(datetime(2024, 1, 15, 11, 59), "morning"), # End of morning
(datetime(2024, 1, 15, 12, 0), "afternoon"), # Start of afternoon
(datetime(2024, 1, 15, 16, 59), "afternoon"), # End of afternoon
(datetime(2024, 1, 15, 17, 0), "evening"), # Start of evening
(datetime(2024, 1, 15, 21, 59), "evening"), # End of evening
(datetime(2024, 1, 15, 22, 0), "late_night"), # Start of late_night
(datetime(2024, 1, 15, 4, 59), "late_night"), # End of late_night
],
)
def test_generate_temporal_text_boundary_cases(
test_date: datetime, expected_period: str
):
result = generate_temporal_text(
subject="subject", content="content", confidence=0.8, created_at=test_date
)
assert f"({expected_period})" in result
def test_generate_temporal_text_complete_format():
test_date = datetime(2024, 3, 22, 14, 45) # Friday afternoon
result = generate_temporal_text(
subject="Important observation",
content="User showed strong preference for X",
confidence=0.95,
created_at=test_date,
)
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X | Confidence: 0.95"
assert result == expected
def test_generate_temporal_text_empty_strings():
test_date = datetime(2024, 1, 15, 10, 30)
result = generate_temporal_text(
subject="", content="", confidence=0.0, created_at=test_date
)
assert (
result
== "Time: 10:30 on Monday (morning) | Subject: | Observation: | Confidence: 0.0"
)
def test_generate_temporal_text_special_characters():
test_date = datetime(2024, 1, 15, 15, 20)
result = generate_temporal_text(
subject="Subject with | pipe",
content="Content with | pipe and @#$ symbols",
confidence=0.75,
created_at=test_date,
)
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols | Confidence: 0.75"
assert result == expected

View File

@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import (
process_content_item,
push_to_qdrant,
safe_task_execution,
by_collection,
)
@ -71,6 +72,7 @@ def mock_chunk():
chunk.id = "00000000-0000-0000-0000-000000000001"
chunk.vector = [0.1] * 1024
chunk.item_metadata = {"source_id": 1, "tags": ["test"]}
chunk.collection_name = "mail"
return chunk
@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant):
mock_chunk1.id = "00000000-0000-0000-0000-000000000001"
mock_chunk1.vector = [0.1] * 1024
mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]}
mock_chunk1.collection_name = "mail"
mock_chunk2 = MagicMock()
mock_chunk2.id = "00000000-0000-0000-0000-000000000002"
mock_chunk2.vector = [0.2] * 1024
mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]}
mock_chunk2.collection_name = "mail"
# Assign chunks directly (bypassing SQLAlchemy relationship)
item1.chunks = [mock_chunk1]
item2.chunks = [mock_chunk2]
push_to_qdrant([item1, item2], "mail")
push_to_qdrant([item1, item2])
assert str(item1.embed_status) == "STORED"
assert str(item2.embed_status) == "STORED"
@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing(
mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}"
mock_chunk.vector = [0.1] * 1024
mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]}
mock_chunk.collection_name = "mail"
item.chunks = [mock_chunk]
else:
item.chunks = []
@ -302,7 +307,7 @@ def test_push_to_qdrant_no_processing(
item1 = create_item("1", item1_status, item1_has_chunks)
item2 = create_item("2", item2_status, item2_has_chunks)
push_to_qdrant([item1, item2], "mail")
push_to_qdrant([item1, item2])
assert str(item1.embed_status) == expected_item1_status
assert str(item2.embed_status) == expected_item2_status
@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk):
side_effect=Exception("Qdrant error"),
):
with pytest.raises(Exception, match="Qdrant error"):
push_to_qdrant([sample_mail_message], "mail")
push_to_qdrant([sample_mail_message])
assert str(sample_mail_message.embed_status) == "FAILED"
@ -418,6 +423,196 @@ def test_create_task_result_no_title():
assert result["chunks_count"] == 0
def test_by_collection_empty_chunks():
result = by_collection([])
assert result == {}
def test_by_collection_single_chunk():
chunk = Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content",
embedding_model="test-model",
vector=[0.1, 0.2, 0.3],
item_metadata={"source_id": 1, "tags": ["test"]},
collection_name="test_collection",
)
result = by_collection([chunk])
assert len(result) == 1
assert "test_collection" in result
assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"]
assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]]
assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}]
def test_by_collection_multiple_chunks_same_collection():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content 1",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"source_id": 1},
collection_name="collection_a",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.3, 0.4],
item_metadata={"source_id": 2},
collection_name="collection_a",
),
]
result = by_collection(chunks)
assert len(result) == 1
assert "collection_a" in result
assert result["collection_a"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
]
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]]
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}]
def test_by_collection_multiple_chunks_different_collections():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content 1",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"source_id": 1},
collection_name="collection_a",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.3, 0.4],
item_metadata={"source_id": 2},
collection_name="collection_b",
),
Chunk(
id="00000000-0000-0000-0000-000000000003",
content="test content 3",
embedding_model="test-model",
vector=[0.5, 0.6],
item_metadata={"source_id": 3},
collection_name="collection_a",
),
]
result = by_collection(chunks)
assert len(result) == 2
assert "collection_a" in result
assert "collection_b" in result
# Check collection_a
assert result["collection_a"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000003",
]
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]]
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}]
# Check collection_b
assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"]
assert result["collection_b"]["vectors"] == [[0.3, 0.4]]
assert result["collection_b"]["payloads"] == [{"source_id": 2}]
@pytest.mark.parametrize(
"collection_names,expected_collections",
[
(["col1", "col1", "col1"], 1),
(["col1", "col2", "col3"], 3),
(["col1", "col2", "col1", "col2"], 2),
(["single"], 1),
],
)
def test_by_collection_various_groupings(collection_names, expected_collections):
chunks = [
Chunk(
id=f"00000000-0000-0000-0000-00000000000{i}",
content=f"test content {i}",
embedding_model="test-model",
vector=[float(i)],
item_metadata={"index": i},
collection_name=collection_name,
)
for i, collection_name in enumerate(collection_names, 1)
]
result = by_collection(chunks)
assert len(result) == expected_collections
# Verify all chunks are accounted for
total_chunks = sum(len(coll["ids"]) for coll in result.values())
assert total_chunks == len(chunks)
def test_by_collection_with_none_values():
chunks = [
Chunk(
id="00000000-0000-0000-0000-000000000001",
content="test content",
embedding_model="test-model",
vector=None, # None vector
item_metadata=None, # None metadata
collection_name="test_collection",
),
Chunk(
id="00000000-0000-0000-0000-000000000002",
content="test content 2",
embedding_model="test-model",
vector=[0.1, 0.2],
item_metadata={"key": "value"},
collection_name="test_collection",
),
]
result = by_collection(chunks)
assert len(result) == 1
assert "test_collection" in result
assert result["test_collection"]["ids"] == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
]
assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]]
assert result["test_collection"]["payloads"] == [None, {"key": "value"}]
def test_by_collection_preserves_order():
chunks = []
for i in range(5):
chunks.append(
Chunk(
id=f"00000000-0000-0000-0000-00000000000{i}",
content=f"test content {i}",
embedding_model="test-model",
vector=[float(i)],
item_metadata={"order": i},
collection_name="ordered_collection",
)
)
result = by_collection(chunks)
assert len(result) == 1
assert result["ordered_collection"]["ids"] == [
f"00000000-0000-0000-0000-00000000000{i}" for i in range(5)
]
assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)]
assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)]
@pytest.mark.parametrize(
"embedding_return,qdrant_error,expected_status,expected_embed_status",
[
@ -459,6 +654,7 @@ def test_process_content_item(
embedding_model="test-model",
vector=[0.1] * 1024,
item_metadata={"source_id": 1, "tags": ["test"]},
collection_name="mail",
)
mock_chunks = [real_chunk]
else: # empty

View File

@ -2,6 +2,7 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
from typing import cast
import pytest
from PIL import Image
@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=chunk_id,
source=item,
content=f"Test chunk content {i}",
collection_name=item.modality,
embedding_model="test-model",
)
for i, chunk_id in enumerate(chunk_ids)
@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
db_session.commit()
# Add vectors to Qdrant
modality = "mail" if item_type == "MailMessage" else "blog"
modality = cast(str, item.modality)
qd.ensure_collection_exists(qdrant, modality, 1024)
qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=str(uuid.uuid4()),
content="New chunk content 1",
embedding_model="test-model",
collection_name=modality,
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
id=str(uuid.uuid4()),
content="New chunk content 2",
embedding_model="test-model",
collection_name=modality,
vector=[0.2] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant):
id=str(uuid.uuid4()),
content="New chunk content",
embedding_model="test-model",
collection_name=item.modality,
vector=[0.1] * 1024,
item_metadata={"source_id": item.id, "tags": ["test"]},
),
@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type):
source=item_with_chunks,
content="Test chunk content",
embedding_model="test-model",
collection_name=item_with_chunks.modality,
)
db_session.add(chunk)
db_session.commit()