mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-30 06:36:07 +02:00
Compare commits
3 Commits
e505f9b53c
...
b10a1fb130
Author | SHA1 | Date | |
---|---|---|---|
![]() |
b10a1fb130 | ||
![]() |
1dd93929c1 | ||
![]() |
004bd39987 |
@ -12,6 +12,32 @@ from alembic import context
|
||||
from memory.common import settings
|
||||
from memory.common.db.models import Base
|
||||
|
||||
# Import all models to ensure they're registered with Base.metadata
|
||||
from memory.common.db.models import (
|
||||
SourceItem,
|
||||
Chunk,
|
||||
MailMessage,
|
||||
EmailAttachment,
|
||||
ChatMessage,
|
||||
BlogPost,
|
||||
Comic,
|
||||
BookSection,
|
||||
ForumPost,
|
||||
GithubItem,
|
||||
GitCommit,
|
||||
Photo,
|
||||
MiscDoc,
|
||||
AgentObservation,
|
||||
ObservationContradiction,
|
||||
ReactionPattern,
|
||||
ObservationPattern,
|
||||
BeliefCluster,
|
||||
ConversationMetrics,
|
||||
Book,
|
||||
ArticleFeed,
|
||||
EmailAccount,
|
||||
)
|
||||
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
279
db/migrations/versions/20250531_154947_add_observation_models.py
Normal file
279
db/migrations/versions/20250531_154947_add_observation_models.py
Normal file
@ -0,0 +1,279 @@
|
||||
"""Add observation models
|
||||
|
||||
Revision ID: 6554eb260176
|
||||
Revises: 2524646f56f6
|
||||
Create Date: 2025-05-31 15:49:47.579256
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6554eb260176"
|
||||
down_revision: Union[str, None] = "2524646f56f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"belief_cluster",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("cluster_name", sa.Text(), nullable=False),
|
||||
sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False),
|
||||
sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True
|
||||
),
|
||||
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"last_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"belief_cluster_consistency_idx",
|
||||
"belief_cluster",
|
||||
["internal_consistency"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"conversation_metrics",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("session_id", sa.UUID(), nullable=False),
|
||||
sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True),
|
||||
sa.Column("breakthrough_count", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True
|
||||
),
|
||||
sa.Column("new_insights", sa.Integer(), nullable=True),
|
||||
sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True),
|
||||
sa.Column("duration_minutes", sa.Integer(), nullable=True),
|
||||
sa.Column("observation_count", sa.Integer(), nullable=True),
|
||||
sa.Column("contradiction_count", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"conv_metrics_breakthrough_idx",
|
||||
"conversation_metrics",
|
||||
["breakthrough_count"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True
|
||||
)
|
||||
op.create_table(
|
||||
"observation_pattern",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("pattern_type", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=False),
|
||||
sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False),
|
||||
sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True),
|
||||
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
|
||||
sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"obs_pattern_confidence_idx",
|
||||
"observation_pattern",
|
||||
["confidence"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"obs_pattern_validity_idx",
|
||||
"observation_pattern",
|
||||
["validity_start", "validity_end"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"reaction_pattern",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("trigger_type", sa.Text(), nullable=False),
|
||||
sa.Column("reaction_type", sa.Text(), nullable=False),
|
||||
sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False),
|
||||
sa.Column(
|
||||
"first_observed",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"last_observed",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
|
||||
sa.Column(
|
||||
"reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"agent_observation",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("session_id", sa.UUID(), nullable=True),
|
||||
sa.Column("observation_type", sa.Text(), nullable=False),
|
||||
sa.Column("subject", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
|
||||
sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("agent_model", sa.Text(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"agent_obs_session_idx", "agent_observation", ["session_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"agent_obs_subject_idx", "agent_observation", ["subject"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"observation_contradiction",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("observation_1_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("observation_2_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("contradiction_type", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"detected_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("detection_method", sa.Text(), nullable=False),
|
||||
sa.Column("resolution", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"observation_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"obs_contra_method_idx",
|
||||
"observation_contradiction",
|
||||
["detection_method"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"obs_contra_obs1_idx",
|
||||
"observation_contradiction",
|
||||
["observation_1_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"obs_contra_obs2_idx",
|
||||
"observation_contradiction",
|
||||
["observation_2_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"obs_contra_type_idx",
|
||||
"observation_contradiction",
|
||||
["contradiction_type"],
|
||||
unique=False,
|
||||
)
|
||||
op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chunk", "collection_name")
|
||||
op.drop_index("obs_contra_type_idx", table_name="observation_contradiction")
|
||||
op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction")
|
||||
op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction")
|
||||
op.drop_index("obs_contra_method_idx", table_name="observation_contradiction")
|
||||
op.drop_table("observation_contradiction")
|
||||
op.drop_index("agent_obs_type_idx", table_name="agent_observation")
|
||||
op.drop_index("agent_obs_subject_idx", table_name="agent_observation")
|
||||
op.drop_index("agent_obs_session_idx", table_name="agent_observation")
|
||||
op.drop_index("agent_obs_model_idx", table_name="agent_observation")
|
||||
op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
|
||||
op.drop_table("agent_observation")
|
||||
op.drop_index("reaction_type_idx", table_name="reaction_pattern")
|
||||
op.drop_index("reaction_trigger_idx", table_name="reaction_pattern")
|
||||
op.drop_index("reaction_frequency_idx", table_name="reaction_pattern")
|
||||
op.drop_table("reaction_pattern")
|
||||
op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern")
|
||||
op.drop_index("obs_pattern_type_idx", table_name="observation_pattern")
|
||||
op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern")
|
||||
op.drop_table("observation_pattern")
|
||||
op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics")
|
||||
op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics")
|
||||
op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics")
|
||||
op.drop_table("conversation_metrics")
|
||||
op.drop_index("belief_cluster_name_idx", table_name="belief_cluster")
|
||||
op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster")
|
||||
op.drop_table("belief_cluster")
|
@ -3,3 +3,4 @@ uvicorn==0.29.0
|
||||
python-jose==3.3.0
|
||||
python-multipart==0.0.9
|
||||
sqladmin
|
||||
mcp==1.9.2
|
@ -5,4 +5,6 @@ alembic==1.13.1
|
||||
dotenv==0.9.9
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
||||
anthropic==0.18.1
|
||||
anthropic==0.18.1
|
||||
|
||||
bm25s[full]==0.2.13
|
654
src/memory/api/MCP/tools.py
Normal file
654
src/memory/api/MCP/tools.py
Normal file
@ -0,0 +1,654 @@
|
||||
"""
|
||||
MCP tools for the epistemic sparring partner system.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from hashlib import sha256
|
||||
import logging
|
||||
import uuid
|
||||
from typing import cast
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from memory.common.db.models.source_item import SourceItem
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy import func, cast as sql_cast, Text
|
||||
from memory.common.db.connection import make_session
|
||||
|
||||
from memory.common import extract
|
||||
from memory.common.db.models import AgentObservation
|
||||
from memory.api.search import search, SearchFilters
|
||||
from memory.common.formatters import observation
|
||||
from memory.workers.tasks.content_processing import process_content_item
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create MCP server instance
|
||||
mcp = FastMCP("memory", stateless=True)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_tags() -> list[str]:
|
||||
"""
|
||||
Get all unique tags used across the entire knowledge base.
|
||||
|
||||
Purpose:
|
||||
This tool retrieves all tags that have been used in the system, both from
|
||||
AI observations (created with 'observe') and other content. Use it to
|
||||
understand the tag taxonomy, ensure consistency, or discover related topics.
|
||||
|
||||
When to use:
|
||||
- Before creating new observations, to use consistent tag naming
|
||||
- To explore what topics/contexts have been tracked
|
||||
- To build tag filters for search operations
|
||||
- To understand the user's areas of interest
|
||||
- For tag autocomplete or suggestion features
|
||||
|
||||
Returns:
|
||||
Sorted list of all unique tags in the system. Tags follow patterns like:
|
||||
- Topics: "machine-learning", "functional-programming"
|
||||
- Projects: "project:website-redesign"
|
||||
- Contexts: "context:work", "context:late-night"
|
||||
- Domains: "domain:finance"
|
||||
|
||||
Example:
|
||||
# Get all tags to ensure consistency
|
||||
tags = await get_all_tags()
|
||||
# Returns: ["ai-safety", "context:work", "functional-programming",
|
||||
# "machine-learning", "project:thesis", ...]
|
||||
|
||||
# Use to check if a topic has been discussed before
|
||||
if "quantum-computing" in tags:
|
||||
# Search for related observations
|
||||
observations = await search_observations(
|
||||
query="quantum computing",
|
||||
tags=["quantum-computing"]
|
||||
)
|
||||
"""
|
||||
with make_session() as session:
|
||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_subjects() -> list[str]:
|
||||
"""
|
||||
Get all unique subjects from observations about the user.
|
||||
|
||||
Purpose:
|
||||
This tool retrieves all subject identifiers that have been used in
|
||||
observations (created with 'observe'). Subjects are the consistent
|
||||
identifiers for what observations are about. Use this to understand
|
||||
what aspects of the user have been tracked and ensure consistency.
|
||||
|
||||
When to use:
|
||||
- Before creating new observations, to use existing subject names
|
||||
- To discover what aspects of the user have been observed
|
||||
- To build subject filters for targeted searches
|
||||
- To ensure consistent naming across observations
|
||||
- To get an overview of the user model
|
||||
|
||||
Returns:
|
||||
Sorted list of all unique subjects. Common patterns include:
|
||||
- "programming_style", "programming_philosophy"
|
||||
- "work_habits", "work_schedule"
|
||||
- "ai_beliefs", "ai_safety_beliefs"
|
||||
- "learning_preferences"
|
||||
- "communication_style"
|
||||
|
||||
Example:
|
||||
# Get all subjects to ensure consistency
|
||||
subjects = await get_all_subjects()
|
||||
# Returns: ["ai_safety_beliefs", "architecture_preferences",
|
||||
# "programming_philosophy", "work_schedule", ...]
|
||||
|
||||
# Use to check what we know about the user
|
||||
if "programming_style" in subjects:
|
||||
# Get all programming-related observations
|
||||
observations = await search_observations(
|
||||
query="programming",
|
||||
subject="programming_style"
|
||||
)
|
||||
|
||||
Best practices:
|
||||
- Always check existing subjects before creating new ones
|
||||
- Use snake_case for consistency
|
||||
- Be specific but not too granular
|
||||
- Group related observations under same subject
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_all_observation_types() -> list[str]:
|
||||
"""
|
||||
Get all unique observation types that have been used.
|
||||
|
||||
Purpose:
|
||||
This tool retrieves the distinct observation types that have been recorded
|
||||
in the system. While the standard types are predefined (belief, preference,
|
||||
behavior, contradiction, general), this shows what's actually been used.
|
||||
Helpful for understanding the distribution of observation types.
|
||||
|
||||
When to use:
|
||||
- To see what types of observations have been made
|
||||
- To understand the balance of different observation types
|
||||
- To check if all standard types are being utilized
|
||||
- For analytics or reporting on observation patterns
|
||||
|
||||
Standard types:
|
||||
- "belief": Opinions or beliefs the user holds
|
||||
- "preference": Things they prefer or favor
|
||||
- "behavior": Patterns in how they act or work
|
||||
- "contradiction": Noted inconsistencies
|
||||
- "general": Observations that don't fit other categories
|
||||
|
||||
Returns:
|
||||
List of observation types that have actually been used in the system.
|
||||
|
||||
Example:
|
||||
# Check what types of observations exist
|
||||
types = await get_all_observation_types()
|
||||
# Returns: ["behavior", "belief", "contradiction", "preference"]
|
||||
|
||||
# Use to analyze observation distribution
|
||||
for obs_type in types:
|
||||
observations = await search_observations(
|
||||
query="",
|
||||
observation_types=[obs_type],
|
||||
limit=100
|
||||
)
|
||||
print(f"{obs_type}: {len(observations)} observations")
|
||||
"""
|
||||
with make_session() as session:
|
||||
return sorted(
|
||||
{
|
||||
r.observation_type
|
||||
for r in session.query(AgentObservation.observation_type).distinct()
|
||||
if r.observation_type is not None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_knowledge_base(
|
||||
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search through the user's stored knowledge and content.
|
||||
|
||||
Purpose:
|
||||
This tool searches the user's personal knowledge base - a collection of
|
||||
their saved content including emails, documents, blog posts, books, and
|
||||
more. Use this alongside 'search_observations' to build a complete picture:
|
||||
- search_knowledge_base: Finds user's actual content and information
|
||||
- search_observations: Finds AI-generated insights about the user
|
||||
Together they enable deeply personalized, context-aware assistance.
|
||||
|
||||
When to use:
|
||||
- User asks about something they've read/written/received
|
||||
- You need to find specific content the user has saved
|
||||
- User references a document, email, or article
|
||||
- To provide quotes or information from user's sources
|
||||
- To understand context from user's past communications
|
||||
- When user says "that article about..." or similar references
|
||||
|
||||
How it works:
|
||||
Uses hybrid search combining semantic understanding with keyword matching.
|
||||
This means it finds content based on meaning AND specific terms, giving
|
||||
you the best of both approaches. Results are ranked by relevance.
|
||||
|
||||
Args:
|
||||
query: Natural language search query. Be descriptive about what you're
|
||||
looking for. The search understands meaning but also values exact terms.
|
||||
Examples:
|
||||
- "email about project deadline from last week"
|
||||
- "functional programming articles comparing Haskell and Scala"
|
||||
- "that blog post about AI safety and alignment"
|
||||
- "recipe for chocolate cake Sarah sent me"
|
||||
Pro tip: Include both concepts and specific keywords for best results.
|
||||
|
||||
previews: Whether to include content snippets in results.
|
||||
- True: Returns preview text and image previews (useful for quick scanning)
|
||||
- False: Returns just metadata (faster, less data)
|
||||
Default is False.
|
||||
|
||||
modalities: Types of content to search. Leave empty to search all.
|
||||
Available types:
|
||||
- 'email': Email messages
|
||||
- 'blog': Blog posts and articles
|
||||
- 'book': Book sections and ebooks
|
||||
- 'forum': Forum posts (e.g., LessWrong, Reddit)
|
||||
- 'observation': AI observations (use search_observations instead)
|
||||
- 'photo': Images with extracted text
|
||||
- 'comic': Comics and graphic content
|
||||
- 'webpage': General web pages
|
||||
Examples:
|
||||
- ["email"] - only emails
|
||||
- ["blog", "forum"] - articles and forum posts
|
||||
- [] - search everything
|
||||
|
||||
limit: Maximum results to return (1-100). Default 10.
|
||||
Increase for comprehensive searches, decrease for quick lookups.
|
||||
|
||||
Returns:
|
||||
List of search results ranked by relevance, each containing:
|
||||
- id: Unique identifier for the source item
|
||||
- score: Relevance score (0-1, higher is better)
|
||||
- chunks: Matching content segments with metadata
|
||||
- content: Full details including:
|
||||
- For emails: sender, recipient, subject, date
|
||||
- For blogs: author, title, url, publish date
|
||||
- For books: title, author, chapter info
|
||||
- Type-specific fields for each modality
|
||||
- filename: Path to file if content is stored on disk
|
||||
|
||||
Examples:
|
||||
# Find specific email
|
||||
results = await search_knowledge_base(
|
||||
query="Sarah deadline project proposal next Friday",
|
||||
modalities=["email"],
|
||||
previews=True,
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Search for technical articles
|
||||
results = await search_knowledge_base(
|
||||
query="functional programming monads category theory",
|
||||
modalities=["blog", "book"],
|
||||
limit=20
|
||||
)
|
||||
|
||||
# Find everything about a topic
|
||||
results = await search_knowledge_base(
|
||||
query="machine learning deployment kubernetes docker",
|
||||
previews=True
|
||||
)
|
||||
|
||||
# Quick lookup of a remembered document
|
||||
results = await search_knowledge_base(
|
||||
query="tax forms 2023 accountant recommendations",
|
||||
modalities=["email"],
|
||||
limit=3
|
||||
)
|
||||
|
||||
Best practices:
|
||||
- Include context in queries ("email from Sarah" vs just "Sarah")
|
||||
- Use modalities to filter when you know the content type
|
||||
- Enable previews when you need to verify content before using
|
||||
- Combine with search_observations for complete context
|
||||
- Higher scores (>0.7) indicate strong matches
|
||||
- If no results, try broader queries or different phrasing
|
||||
"""
|
||||
logger.info(f"MCP search for: {query}")
|
||||
|
||||
upload_data = extract.extract_text(query)
|
||||
results = await search(
|
||||
upload_data,
|
||||
previews=previews,
|
||||
modalities=modalities,
|
||||
limit=limit,
|
||||
min_text_score=0.3,
|
||||
min_multimodal_score=0.3,
|
||||
)
|
||||
|
||||
# Convert SearchResult objects to dictionaries for MCP
|
||||
return [result.model_dump() for result in results]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def observe(
|
||||
content: str,
|
||||
subject: str,
|
||||
observation_type: str = "general",
|
||||
confidence: float = 0.8,
|
||||
evidence: dict | None = None,
|
||||
tags: list[str] | None = None,
|
||||
session_id: str | None = None,
|
||||
agent_model: str = "unknown",
|
||||
) -> dict:
|
||||
"""
|
||||
Record an observation about the user to build long-term understanding.
|
||||
|
||||
Purpose:
|
||||
This tool is part of a memory system designed to help AI agents build a
|
||||
deep, persistent understanding of users over time. Use it to record any
|
||||
notable information about the user's preferences, beliefs, behaviors, or
|
||||
characteristics. These observations accumulate to create a comprehensive
|
||||
model of the user that improves future interactions.
|
||||
|
||||
Quick Reference:
|
||||
# Most common patterns:
|
||||
observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
|
||||
observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
|
||||
observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
|
||||
observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")
|
||||
|
||||
When to use:
|
||||
- User expresses a preference or opinion
|
||||
- You notice a behavioral pattern
|
||||
- User reveals information about their work/life/interests
|
||||
- You spot a contradiction with previous statements
|
||||
- Any insight that would help understand the user better in future
|
||||
|
||||
Important: Be an active observer. Don't wait to be asked - proactively record
|
||||
observations throughout conversations to build understanding.
|
||||
|
||||
Args:
|
||||
content: The observation itself. Be specific and detailed. Write complete
|
||||
thoughts that will make sense when read months later without context.
|
||||
Bad: "Likes FP"
|
||||
Good: "User strongly prefers functional programming paradigms, especially
|
||||
pure functions and immutability, considering them more maintainable"
|
||||
|
||||
subject: A consistent identifier for what this observation is about. Use
|
||||
snake_case and be consistent across observations to enable tracking.
|
||||
Examples:
|
||||
- "programming_style" (not "coding" or "development")
|
||||
- "work_habits" (not "productivity" or "work_patterns")
|
||||
- "ai_safety_beliefs" (not "AI" or "artificial_intelligence")
|
||||
|
||||
observation_type: Categorize the observation:
|
||||
- "belief": An opinion or belief the user holds
|
||||
- "preference": Something they prefer or favor
|
||||
- "behavior": A pattern in how they act or work
|
||||
- "contradiction": An inconsistency with previous observations
|
||||
- "general": Doesn't fit other categories
|
||||
|
||||
confidence: How certain you are (0.0-1.0):
|
||||
- 1.0: User explicitly stated this
|
||||
- 0.9: Strongly implied or demonstrated repeatedly
|
||||
- 0.8: Inferred with high confidence (default)
|
||||
- 0.7: Probable but with some uncertainty
|
||||
- 0.6 or below: Speculative, use sparingly
|
||||
|
||||
evidence: Supporting context as a dict. Include relevant details:
|
||||
- "quote": Exact words from the user
|
||||
- "context": What prompted this observation
|
||||
- "timestamp": When this was observed
|
||||
- "related_to": Connection to other topics
|
||||
Example: {
|
||||
"quote": "I always refactor to pure functions",
|
||||
"context": "Discussing code review practices"
|
||||
}
|
||||
|
||||
tags: Categorization labels. Use lowercase with hyphens. Common patterns:
|
||||
- Topics: "machine-learning", "web-development", "philosophy"
|
||||
- Projects: "project:website-redesign", "project:thesis"
|
||||
- Contexts: "context:work", "context:personal", "context:late-night"
|
||||
- Domains: "domain:finance", "domain:healthcare"
|
||||
|
||||
session_id: UUID string to group observations from the same conversation.
|
||||
Generate one UUID per conversation and reuse it for all observations
|
||||
in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
agent_model: Which AI model made this observation (e.g., "claude-3-opus",
|
||||
"gpt-4", "claude-3.5-sonnet"). Helps track observation quality.
|
||||
|
||||
Returns:
|
||||
Dict with created observation details:
|
||||
- id: Unique identifier for reference
|
||||
- created_at: Timestamp of creation
|
||||
- subject: The subject as stored
|
||||
- observation_type: The type as stored
|
||||
- confidence: The confidence score
|
||||
- tags: List of applied tags
|
||||
|
||||
Examples:
|
||||
# After user mentions their coding philosophy
|
||||
await observe(
|
||||
content="User believes strongly in functional programming principles, "
|
||||
"particularly avoiding mutable state which they call 'the root "
|
||||
"of all evil'. They prioritize code purity over performance.",
|
||||
subject="programming_philosophy",
|
||||
observation_type="belief",
|
||||
confidence=0.95,
|
||||
evidence={
|
||||
"quote": "State is the root of all evil in programming",
|
||||
"context": "Discussing why they chose Haskell for their project"
|
||||
},
|
||||
tags=["programming", "functional-programming", "philosophy"],
|
||||
session_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
agent_model="claude-3-opus"
|
||||
)
|
||||
|
||||
# Noticing a work pattern
|
||||
await observe(
|
||||
content="User frequently works on complex problems late at night, "
|
||||
"typically between 11pm and 3am, claiming better focus",
|
||||
subject="work_schedule",
|
||||
observation_type="behavior",
|
||||
confidence=0.85,
|
||||
evidence={
|
||||
"context": "Mentioned across multiple conversations over 2 weeks"
|
||||
},
|
||||
tags=["behavior", "work-habits", "productivity", "context:late-night"],
|
||||
agent_model="claude-3-opus"
|
||||
)
|
||||
|
||||
# Recording a contradiction
|
||||
await observe(
|
||||
content="User now advocates for microservices architecture, but "
|
||||
"previously argued strongly for monoliths in similar contexts",
|
||||
subject="architecture_preferences",
|
||||
observation_type="contradiction",
|
||||
confidence=0.9,
|
||||
evidence={
|
||||
"quote": "Microservices are definitely the way to go",
|
||||
"context": "Designing a new system similar to one from 3 months ago"
|
||||
},
|
||||
tags=["architecture", "contradiction", "software-design"],
|
||||
agent_model="gpt-4"
|
||||
)
|
||||
"""
|
||||
# Create the observation
|
||||
observation = AgentObservation(
|
||||
content=content,
|
||||
subject=subject,
|
||||
observation_type=observation_type,
|
||||
confidence=confidence,
|
||||
evidence=evidence,
|
||||
tags=tags or [],
|
||||
session_id=uuid.UUID(session_id) if session_id else None,
|
||||
agent_model=agent_model,
|
||||
size=len(content),
|
||||
mime_type="text/plain",
|
||||
sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
|
||||
inserted_at=datetime.now(timezone.utc),
|
||||
)
|
||||
try:
|
||||
with make_session() as session:
|
||||
process_content_item(observation, session)
|
||||
|
||||
if not observation.id:
|
||||
raise ValueError("Observation not created")
|
||||
|
||||
logger.info(
|
||||
f"Observation created: {observation.id}, {observation.inserted_at}"
|
||||
)
|
||||
return {
|
||||
"id": observation.id,
|
||||
"created_at": observation.inserted_at.isoformat(),
|
||||
"subject": observation.subject,
|
||||
"observation_type": observation.observation_type,
|
||||
"confidence": cast(float, observation.confidence),
|
||||
"tags": observation.tags,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating observation: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_observations(
|
||||
query: str,
|
||||
subject: str = "",
|
||||
tags: list[str] | None = None,
|
||||
observation_types: list[str] | None = None,
|
||||
min_confidence: float = 0.5,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search through observations to understand the user better.
|
||||
|
||||
Purpose:
|
||||
This tool searches through all observations recorded about the user using
|
||||
the 'observe' tool. Use it to recall past insights, check for patterns,
|
||||
find contradictions, or understand the user's preferences before responding.
|
||||
The more you use this tool, the more personalized and insightful your
|
||||
responses can be.
|
||||
|
||||
When to use:
|
||||
- Before answering questions where user preferences might matter
|
||||
- When the user references something from the past
|
||||
- To check if current behavior aligns with past patterns
|
||||
- To find related observations on a topic
|
||||
- To build context about the user's expertise or interests
|
||||
- Whenever personalization would improve your response
|
||||
|
||||
How it works:
|
||||
Uses hybrid search combining semantic similarity with keyword matching.
|
||||
Searches across multiple embedding spaces (semantic meaning and temporal
|
||||
context) to find relevant observations from different angles. This approach
|
||||
ensures you find both conceptually related and specifically mentioned items.
|
||||
|
||||
Args:
|
||||
query: Natural language description of what you're looking for. The search
|
||||
matches both meaning and specific terms in observation content.
|
||||
Examples:
|
||||
- "programming preferences and coding style"
|
||||
- "opinions about artificial intelligence and AI safety"
|
||||
- "work habits productivity patterns when does user work best"
|
||||
- "previous projects the user has worked on"
|
||||
Pro tip: Use natural language but include key terms you expect to find.
|
||||
|
||||
subject: Filter by exact subject identifier. Must match subjects used when
|
||||
creating observations (e.g., "programming_style", "work_habits").
|
||||
Leave empty to search all subjects. Use this when you know the exact
|
||||
subject category you want.
|
||||
|
||||
tags: Filter results to only observations with these tags. Observations must
|
||||
have at least one matching tag. Use the same format as when creating:
|
||||
- ["programming", "functional-programming"]
|
||||
- ["context:work", "project:thesis"]
|
||||
- ["domain:finance", "machine-learning"]
|
||||
|
||||
observation_types: Filter by type of observation:
|
||||
- "belief": Opinions or beliefs the user holds
|
||||
- "preference": Things they prefer or favor
|
||||
- "behavior": Patterns in how they act or work
|
||||
- "contradiction": Noted inconsistencies
|
||||
- "general": Other observations
|
||||
Leave as None to search all types.
|
||||
|
||||
min_confidence: Only return observations with confidence >= this value.
|
||||
- Use 0.8+ for high-confidence facts
|
||||
- Use 0.5-0.7 to include inferred observations
|
||||
- Default 0.5 includes most observations
|
||||
Range: 0.0 to 1.0
|
||||
|
||||
limit: Maximum results to return (1-100). Default 10. Increase when you
|
||||
need comprehensive understanding of a topic.
|
||||
|
||||
Returns:
|
||||
List of observations sorted by relevance, each containing:
|
||||
- subject: What the observation is about
|
||||
- content: The full observation text
|
||||
- observation_type: Type of observation
|
||||
- evidence: Supporting context/quotes if provided
|
||||
- confidence: How certain the observation is (0-1)
|
||||
- agent_model: Which AI model made the observation
|
||||
- tags: All tags on this observation
|
||||
- created_at: When it was observed (if available)
|
||||
|
||||
Examples:
|
||||
# Before discussing code architecture
|
||||
results = await search_observations(
|
||||
query="software architecture preferences microservices monoliths",
|
||||
tags=["architecture"],
|
||||
min_confidence=0.7
|
||||
)
|
||||
|
||||
# Understanding work style for scheduling
|
||||
results = await search_observations(
|
||||
query="when does user work best productivity schedule",
|
||||
observation_types=["behavior", "preference"],
|
||||
subject="work_schedule"
|
||||
)
|
||||
|
||||
# Check for AI safety views before discussing AI
|
||||
results = await search_observations(
|
||||
query="artificial intelligence safety alignment concerns",
|
||||
observation_types=["belief"],
|
||||
min_confidence=0.8,
|
||||
limit=20
|
||||
)
|
||||
|
||||
# Find contradictions on a topic
|
||||
results = await search_observations(
|
||||
query="testing methodology unit tests integration",
|
||||
observation_types=["contradiction"],
|
||||
tags=["testing", "software-development"]
|
||||
)
|
||||
|
||||
Best practices:
|
||||
- Search before making assumptions about user preferences
|
||||
- Use broad queries first, then filter with tags/types if too many results
|
||||
- Check for contradictions when user says something unexpected
|
||||
- Higher confidence observations are more reliable
|
||||
- Recent observations may override older ones on same topic
|
||||
"""
|
||||
source_ids = None
|
||||
if tags or observation_types:
|
||||
with make_session() as session:
|
||||
items_query = session.query(AgentObservation.id)
|
||||
|
||||
if tags:
|
||||
# Use PostgreSQL array overlap operator with proper array casting
|
||||
items_query = items_query.filter(
|
||||
AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||
)
|
||||
if observation_types:
|
||||
items_query = items_query.filter(
|
||||
AgentObservation.observation_type.in_(observation_types)
|
||||
)
|
||||
source_ids = [item.id for item in items_query.all()]
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
semantic_text = observation.generate_semantic_text(
|
||||
subject=subject or "",
|
||||
observation_type="".join(observation_types or []),
|
||||
content=query,
|
||||
evidence=None,
|
||||
)
|
||||
temporal = observation.generate_temporal_text(
|
||||
subject=subject or "",
|
||||
content=query,
|
||||
confidence=0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
results = await search(
|
||||
[
|
||||
extract.DataChunk(data=[query]),
|
||||
extract.DataChunk(data=[semantic_text]),
|
||||
extract.DataChunk(data=[temporal]),
|
||||
],
|
||||
previews=True,
|
||||
modalities=["semantic", "temporal"],
|
||||
limit=limit,
|
||||
min_text_score=0.8,
|
||||
filters=SearchFilters(
|
||||
subject=subject,
|
||||
confidence=min_confidence,
|
||||
tags=tags,
|
||||
observation_types=observation_types,
|
||||
source_ids=source_ids,
|
||||
),
|
||||
)
|
||||
|
||||
return [
|
||||
cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
|
||||
]
|
@ -18,6 +18,7 @@ from memory.common.db.models import (
|
||||
ArticleFeed,
|
||||
EmailAccount,
|
||||
ForumPost,
|
||||
AgentObservation,
|
||||
)
|
||||
|
||||
|
||||
@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
|
||||
|
||||
class ChunkAdmin(ModelView, model=Chunk):
|
||||
column_list = ["id", "source_id", "embedding_model", "created_at"]
|
||||
column_sortable_list = ["created_at"]
|
||||
|
||||
|
||||
class MailMessageAdmin(ModelView, model=MailMessage):
|
||||
@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
|
||||
column_searchable_list = ["name", "email_address"]
|
||||
|
||||
|
||||
class AgentObservationAdmin(ModelView, model=AgentObservation):
|
||||
column_list = [
|
||||
"id",
|
||||
"content",
|
||||
"subject",
|
||||
"observation_type",
|
||||
"confidence",
|
||||
"evidence",
|
||||
"inserted_at",
|
||||
]
|
||||
column_searchable_list = ["subject", "observation_type"]
|
||||
|
||||
|
||||
def setup_admin(admin: Admin):
|
||||
"""Add all admin views to the admin instance."""
|
||||
admin.add_view(SourceItemAdmin)
|
||||
admin.add_view(AgentObservationAdmin)
|
||||
admin.add_view(ChunkAdmin)
|
||||
admin.add_view(EmailAccountAdmin)
|
||||
admin.add_view(MailMessageAdmin)
|
||||
|
@ -2,6 +2,7 @@
|
||||
FastAPI application for the knowledge base.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import pathlib
|
||||
import logging
|
||||
from typing import Annotated, Optional
|
||||
@ -15,15 +16,25 @@ from memory.common import extract
|
||||
from memory.common.db.connection import get_engine
|
||||
from memory.api.admin import setup_admin
|
||||
from memory.api.search import search, SearchResult
|
||||
from memory.api.MCP.tools import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Knowledge Base API")
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
await stack.enter_async_context(mcp.session_manager.run())
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
|
||||
|
||||
# SQLAdmin setup
|
||||
engine = get_engine()
|
||||
admin = Admin(app, engine)
|
||||
setup_admin(admin)
|
||||
app.mount("/", mcp.streamable_http_app())
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@ -104,4 +115,7 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from memory.common.qdrant import setup_qdrant
|
||||
|
||||
setup_qdrant()
|
||||
main()
|
||||
|
@ -2,21 +2,29 @@
|
||||
Search endpoints for the knowledge base API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from hashlib import sha256
|
||||
import io
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Optional
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, TypedDict, NotRequired
|
||||
|
||||
import bm25s
|
||||
import Stemmer
|
||||
import qdrant_client
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
import qdrant_client
|
||||
from qdrant_client.http import models as qdrant_models
|
||||
|
||||
from memory.common import embedding, qdrant, extract, settings
|
||||
from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
|
||||
from memory.common import embedding, extract, qdrant, settings
|
||||
from memory.common.collections import (
|
||||
ALL_COLLECTIONS,
|
||||
MULTIMODAL_COLLECTIONS,
|
||||
TEXT_COLLECTIONS,
|
||||
)
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.db.models import Chunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
|
||||
preview: Optional[str | None] = None
|
||||
|
||||
|
||||
class SourceData(BaseModel):
|
||||
"""Holds source item data to avoid SQLAlchemy session issues"""
|
||||
|
||||
id: int
|
||||
size: int | None
|
||||
mime_type: str | None
|
||||
filename: str | None
|
||||
content: str | dict | None
|
||||
content_length: int
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
collection: str
|
||||
results: list[dict]
|
||||
@ -38,16 +57,46 @@ class SearchResult(BaseModel):
|
||||
size: int
|
||||
mime_type: str
|
||||
chunks: list[AnnotatedChunk]
|
||||
content: Optional[str] = None
|
||||
content: Optional[str | dict] = None
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
class SearchFilters(TypedDict):
|
||||
subject: NotRequired[str | None]
|
||||
confidence: NotRequired[float]
|
||||
tags: NotRequired[list[str] | None]
|
||||
observation_types: NotRequired[list[str] | None]
|
||||
source_ids: NotRequired[list[int] | None]
|
||||
|
||||
|
||||
async def with_timeout(
|
||||
call, timeout: int = 2
|
||||
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
||||
"""
|
||||
Run a function with a timeout.
|
||||
|
||||
Args:
|
||||
call: The function to run
|
||||
timeout: The timeout in seconds
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(call, timeout=timeout)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Search timed out after {timeout}s")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def annotated_chunk(
|
||||
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
|
||||
) -> tuple[SourceItem, AnnotatedChunk]:
|
||||
) -> tuple[SourceData, AnnotatedChunk]:
|
||||
def serialize_item(item: bytes | str | Image.Image) -> str | None:
|
||||
if not previews and not isinstance(item, str):
|
||||
return None
|
||||
if not previews and isinstance(item, str):
|
||||
return item[:100]
|
||||
|
||||
if isinstance(item, Image.Image):
|
||||
buffer = io.BytesIO()
|
||||
@ -68,7 +117,19 @@ def annotated_chunk(
|
||||
for k, v in metadata.items()
|
||||
if k not in ["content", "filename", "size", "content_type", "tags"]
|
||||
}
|
||||
return chunk.source, AnnotatedChunk(
|
||||
|
||||
# Prefetch all needed source data while in session
|
||||
source = chunk.source
|
||||
source_data = SourceData(
|
||||
id=source.id,
|
||||
size=source.size,
|
||||
mime_type=source.mime_type,
|
||||
filename=source.filename,
|
||||
content=source.display_contents,
|
||||
content_length=len(source.content) if source.content else 0,
|
||||
)
|
||||
|
||||
return source_data, AnnotatedChunk(
|
||||
id=str(chunk.id),
|
||||
score=search_result.score,
|
||||
metadata=metadata,
|
||||
@ -76,24 +137,28 @@ def annotated_chunk(
|
||||
)
|
||||
|
||||
|
||||
def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
|
||||
def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
|
||||
items = defaultdict(list)
|
||||
source_lookup = {}
|
||||
|
||||
for source, chunk in chunks:
|
||||
items[source].append(chunk)
|
||||
items[source.id].append(chunk)
|
||||
source_lookup[source.id] = source
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
id=source.id,
|
||||
size=source.size or len(source.content),
|
||||
size=source.size or source.content_length,
|
||||
mime_type=source.mime_type or "text/plain",
|
||||
filename=source.filename
|
||||
and source.filename.replace(
|
||||
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
|
||||
),
|
||||
content=source.display_contents,
|
||||
content=source.content,
|
||||
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
|
||||
)
|
||||
for source, chunks in items.items()
|
||||
for source_id, chunks in items.items()
|
||||
for source in [source_lookup[source_id]]
|
||||
]
|
||||
|
||||
|
||||
@ -104,25 +169,31 @@ def query_chunks(
|
||||
embedder: Callable,
|
||||
min_score: float = 0.0,
|
||||
limit: int = 10,
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> dict[str, list[qdrant_models.ScoredPoint]]:
|
||||
if not upload_data:
|
||||
if not upload_data or not allowed_modalities:
|
||||
return {}
|
||||
|
||||
chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
|
||||
chunks = [chunk for chunk in upload_data if chunk.data]
|
||||
if not chunks:
|
||||
logger.error(f"No chunks to embed for {allowed_modalities}")
|
||||
return {}
|
||||
|
||||
vector = embedder(chunks, input_type="query")[0]
|
||||
logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
|
||||
for c in chunks:
|
||||
logger.error(f"Chunk: {c.data}")
|
||||
vectors = embedder([c.data for c in chunks], input_type="query")
|
||||
|
||||
return {
|
||||
collection: [
|
||||
r
|
||||
for vector in vectors
|
||||
for r in qdrant.search_vectors(
|
||||
client=client,
|
||||
collection_name=collection,
|
||||
query_vector=vector,
|
||||
limit=limit,
|
||||
filter_params=filters,
|
||||
)
|
||||
if r.score >= min_score
|
||||
]
|
||||
@ -130,6 +201,122 @@ def query_chunks(
|
||||
}
|
||||
|
||||
|
||||
async def search_bm25(
|
||||
query: str,
|
||||
modalities: list[str],
|
||||
limit: int = 10,
|
||||
filters: SearchFilters = SearchFilters(),
|
||||
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
||||
with make_session() as db:
|
||||
items_query = db.query(Chunk.id, Chunk.content).filter(
|
||||
Chunk.collection_name.in_(modalities)
|
||||
)
|
||||
if source_ids := filters.get("source_ids"):
|
||||
items_query = items_query.filter(Chunk.source_id.in_(source_ids))
|
||||
items = items_query.all()
|
||||
item_ids = {
|
||||
sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
|
||||
for item in items
|
||||
}
|
||||
corpus = [item.content.lower().strip() for item in items]
|
||||
|
||||
stemmer = Stemmer.Stemmer("english")
|
||||
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
|
||||
retriever = bm25s.BM25()
|
||||
retriever.index(corpus_tokens)
|
||||
|
||||
query_tokens = bm25s.tokenize(query, stemmer=stemmer)
|
||||
results, scores = retriever.retrieve(
|
||||
query_tokens, k=min(limit, len(corpus)), corpus=corpus
|
||||
)
|
||||
|
||||
item_scores = {
|
||||
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
|
||||
for doc, score in zip(results[0], scores[0])
|
||||
}
|
||||
|
||||
with make_session() as db:
|
||||
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
|
||||
results = []
|
||||
for chunk in chunks:
|
||||
# Prefetch all needed source data while in session
|
||||
source = chunk.source
|
||||
source_data = SourceData(
|
||||
id=source.id,
|
||||
size=source.size,
|
||||
mime_type=source.mime_type,
|
||||
filename=source.filename,
|
||||
content=source.display_contents,
|
||||
content_length=len(source.content) if source.content else 0,
|
||||
)
|
||||
|
||||
annotated = AnnotatedChunk(
|
||||
id=str(chunk.id),
|
||||
score=item_scores[chunk.id],
|
||||
metadata=source.as_payload(),
|
||||
preview=None,
|
||||
)
|
||||
results.append((source_data, annotated))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_embeddings(
|
||||
data: list[extract.DataChunk],
|
||||
previews: Optional[bool] = False,
|
||||
modalities: set[str] = set(),
|
||||
limit: int = 10,
|
||||
min_score: float = 0.3,
|
||||
filters: SearchFilters = SearchFilters(),
|
||||
multimodal: bool = False,
|
||||
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
||||
"""
|
||||
Search across knowledge base using text query and optional files.
|
||||
|
||||
Parameters:
|
||||
- data: List of data to search in (e.g., text, images, files)
|
||||
- previews: Whether to include previews in the search results
|
||||
- modalities: List of modalities to search in (e.g., "text", "photo", "doc")
|
||||
- limit: Maximum number of results
|
||||
- min_score: Minimum score to include in the search results
|
||||
- filters: Filters to apply to the search results
|
||||
- multimodal: Whether to search in multimodal collections
|
||||
"""
|
||||
query_filters = {
|
||||
"must": [
|
||||
{"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
|
||||
],
|
||||
}
|
||||
if tags := filters.get("tags"):
|
||||
query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
|
||||
if observation_types := filters.get("observation_types"):
|
||||
query_filters["must"] += [
|
||||
{"key": "observation_type", "match": {"any": observation_types}}
|
||||
]
|
||||
|
||||
client = qdrant.get_qdrant_client()
|
||||
results = query_chunks(
|
||||
client,
|
||||
data,
|
||||
modalities,
|
||||
embedding.embed_text if not multimodal else embedding.embed_mixed,
|
||||
min_score=min_score,
|
||||
limit=limit,
|
||||
filters=query_filters,
|
||||
)
|
||||
search_results = {k: results.get(k, []) for k in modalities}
|
||||
|
||||
found_chunks = {
|
||||
str(r.id): r for results in search_results.values() for r in results
|
||||
}
|
||||
with make_session() as db:
|
||||
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
|
||||
return [
|
||||
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
|
||||
async def search(
|
||||
data: list[extract.DataChunk],
|
||||
previews: Optional[bool] = False,
|
||||
@ -137,6 +324,7 @@ async def search(
|
||||
limit: int = 10,
|
||||
min_text_score: float = 0.3,
|
||||
min_multimodal_score: float = 0.3,
|
||||
filters: SearchFilters = {},
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search across knowledge base using text query and optional files.
|
||||
@ -150,40 +338,45 @@ async def search(
|
||||
Returns:
|
||||
- List of search results sorted by score
|
||||
"""
|
||||
client = qdrant.get_qdrant_client()
|
||||
allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
|
||||
text_results = query_chunks(
|
||||
client,
|
||||
data,
|
||||
allowed_modalities & TEXT_COLLECTIONS,
|
||||
embedding.embed_text,
|
||||
min_score=min_text_score,
|
||||
limit=limit,
|
||||
)
|
||||
multimodal_results = query_chunks(
|
||||
client,
|
||||
data,
|
||||
allowed_modalities,
|
||||
embedding.embed_mixed,
|
||||
min_score=min_multimodal_score,
|
||||
limit=limit,
|
||||
)
|
||||
search_results = {
|
||||
k: text_results.get(k, []) + multimodal_results.get(k, [])
|
||||
for k in allowed_modalities
|
||||
}
|
||||
|
||||
found_chunks = {
|
||||
str(r.id): r for results in search_results.values() for r in results
|
||||
}
|
||||
with make_session() as db:
|
||||
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
|
||||
logger.error(f"Found chunks: {chunks}")
|
||||
|
||||
results = group_chunks(
|
||||
[
|
||||
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
|
||||
for chunk in chunks
|
||||
]
|
||||
text_embeddings_results = with_timeout(
|
||||
search_embeddings(
|
||||
data,
|
||||
previews,
|
||||
allowed_modalities & TEXT_COLLECTIONS,
|
||||
limit,
|
||||
min_text_score,
|
||||
filters,
|
||||
multimodal=False,
|
||||
)
|
||||
)
|
||||
multimodal_embeddings_results = with_timeout(
|
||||
search_embeddings(
|
||||
data,
|
||||
previews,
|
||||
allowed_modalities & MULTIMODAL_COLLECTIONS,
|
||||
limit,
|
||||
min_multimodal_score,
|
||||
filters,
|
||||
multimodal=True,
|
||||
)
|
||||
)
|
||||
bm25_results = with_timeout(
|
||||
search_bm25(
|
||||
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
|
||||
modalities,
|
||||
limit=limit,
|
||||
filters=filters,
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
text_embeddings_results,
|
||||
multimodal_embeddings_results,
|
||||
bm25_results,
|
||||
return_exceptions=False,
|
||||
)
|
||||
|
||||
results = group_chunks([c for r in results for c in r])
|
||||
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings
|
||||
|
||||
@ -14,73 +15,92 @@ Vector = list[float]
|
||||
class Collection(TypedDict):
|
||||
dimension: int
|
||||
distance: DistanceType
|
||||
model: str
|
||||
on_disk: NotRequired[bool]
|
||||
shards: NotRequired[int]
|
||||
text: bool
|
||||
multimodal: bool
|
||||
|
||||
|
||||
ALL_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
"chat": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": True,
|
||||
},
|
||||
"git": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
"book": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
"blog": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": True,
|
||||
},
|
||||
"forum": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": True,
|
||||
},
|
||||
"text": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.TEXT_EMBEDDING_MODEL,
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
# Multimodal
|
||||
"photo": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
"text": False,
|
||||
"multimodal": True,
|
||||
},
|
||||
"comic": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
"text": False,
|
||||
"multimodal": True,
|
||||
},
|
||||
"doc": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"model": settings.MIXED_EMBEDDING_MODEL,
|
||||
"text": False,
|
||||
"multimodal": True,
|
||||
},
|
||||
# Observations
|
||||
"semantic": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
"temporal": {
|
||||
"dimension": 1024,
|
||||
"distance": "Cosine",
|
||||
"text": True,
|
||||
"multimodal": False,
|
||||
},
|
||||
}
|
||||
TEXT_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.TEXT_EMBEDDING_MODEL
|
||||
coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
|
||||
}
|
||||
MULTIMODAL_COLLECTIONS = {
|
||||
coll
|
||||
for coll, params in ALL_COLLECTIONS.items()
|
||||
if params["model"] == settings.MIXED_EMBEDDING_MODEL
|
||||
coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
|
||||
}
|
||||
|
||||
TYPES = {
|
||||
@ -108,5 +128,12 @@ def get_modality(mime_type: str) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def collection_model(collection: str) -> str | None:
|
||||
return ALL_COLLECTIONS.get(collection, {}).get("model")
|
||||
def collection_model(
|
||||
collection: str, text: str, images: list[Image.Image]
|
||||
) -> str | None:
|
||||
config = ALL_COLLECTIONS.get(collection, {})
|
||||
if images and config.get("multimodal"):
|
||||
return settings.MIXED_EMBEDDING_MODEL
|
||||
if text and config.get("text"):
|
||||
return settings.TEXT_EMBEDDING_MODEL
|
||||
return "unknown"
|
||||
|
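A minimal usage sketch of the new collection_model signature (illustrative values only, assuming the TEXT/MIXED settings referenced above): the embedding model is now chosen from what the chunk actually contains, not just from the collection name.

from PIL import Image
from memory.common.collections import collection_model

img = Image.new("RGB", (8, 8))  # hypothetical stand-in for a real image
collection_model("chat", "hello", [])   # text only -> settings.TEXT_EMBEDDING_MODEL
collection_model("chat", "", [img])     # image in a multimodal collection -> settings.MIXED_EMBEDDING_MODEL
collection_model("book", "", [img])     # text-only collection with no text -> "unknown"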
@ -2,7 +2,7 @@
|
||||
Database utilities package.
|
||||
"""
|
||||
|
||||
from memory.common.db.models import Base
|
||||
from memory.common.db.models.base import Base
|
||||
from memory.common.db.connection import (
|
||||
get_engine,
|
||||
get_session_factory,
|
||||
|
src/memory/common/db/models/__init__.py (new file, 61 lines)
@ -0,0 +1,61 @@
from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
    Chunk,
    SourceItem,
    clean_filename,
)
from memory.common.db.models.source_items import (
    MailMessage,
    EmailAttachment,
    AgentObservation,
    ChatMessage,
    BlogPost,
    Comic,
    BookSection,
    ForumPost,
    GithubItem,
    GitCommit,
    Photo,
    MiscDoc,
)
from memory.common.db.models.observations import (
    ObservationContradiction,
    ReactionPattern,
    ObservationPattern,
    BeliefCluster,
    ConversationMetrics,
)
from memory.common.db.models.sources import (
    Book,
    ArticleFeed,
    EmailAccount,
)

__all__ = [
    "Base",
    "Chunk",
    "clean_filename",
    "SourceItem",
    "MailMessage",
    "EmailAttachment",
    "AgentObservation",
    "ChatMessage",
    "BlogPost",
    "Comic",
    "BookSection",
    "ForumPost",
    "GithubItem",
    "GitCommit",
    "Photo",
    "MiscDoc",
    # Observations
    "ObservationContradiction",
    "ReactionPattern",
    "ObservationPattern",
    "BeliefCluster",
    "ConversationMetrics",
    # Sources
    "Book",
    "ArticleFeed",
    "EmailAccount",
]
src/memory/common/db/models/base.py (new file, 4 lines)
@ -0,0 +1,4 @@
from sqlalchemy.ext.declarative import declarative_base


Base = declarative_base()
src/memory/common/db/models/observations.py (new file, 171 lines)
@ -0,0 +1,171 @@
|
||||
"""
|
||||
Agent observation models for the epistemic sparring partner system.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
Text,
|
||||
UUID,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class ObservationContradiction(Base):
|
||||
"""
|
||||
Tracks contradictions between observations.
|
||||
Can be detected automatically or reported by agents.
|
||||
"""
|
||||
|
||||
__tablename__ = "observation_contradiction"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
observation_1_id = Column(
|
||||
BigInteger,
|
||||
ForeignKey("agent_observation.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
observation_2_id = Column(
|
||||
BigInteger,
|
||||
ForeignKey("agent_observation.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
contradiction_type = Column(Text, nullable=False) # direct, implied, temporal
|
||||
detected_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
detection_method = Column(Text, nullable=False) # manual, automatic, agent-reported
|
||||
resolution = Column(Text) # How it was resolved, if at all
|
||||
observation_metadata = Column(JSONB)
|
||||
|
||||
# Relationships - use string references to avoid circular imports
|
||||
observation_1 = relationship(
|
||||
"AgentObservation",
|
||||
foreign_keys=[observation_1_id],
|
||||
back_populates="contradictions_as_first",
|
||||
)
|
||||
observation_2 = relationship(
|
||||
"AgentObservation",
|
||||
foreign_keys=[observation_2_id],
|
||||
back_populates="contradictions_as_second",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("obs_contra_obs1_idx", "observation_1_id"),
|
||||
Index("obs_contra_obs2_idx", "observation_2_id"),
|
||||
Index("obs_contra_type_idx", "contradiction_type"),
|
||||
Index("obs_contra_method_idx", "detection_method"),
|
||||
)
|
||||
|
||||
|
||||
class ReactionPattern(Base):
|
||||
"""
|
||||
Tracks patterns in how the user reacts to certain types of observations or challenges.
|
||||
"""
|
||||
|
||||
__tablename__ = "reaction_pattern"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
trigger_type = Column(
|
||||
Text, nullable=False
|
||||
) # What kind of observation triggers this
|
||||
reaction_type = Column(Text, nullable=False) # How user typically responds
|
||||
frequency = Column(Numeric(3, 2), nullable=False) # How often this pattern appears
|
||||
first_observed = Column(DateTime(timezone=True), server_default=func.now())
|
||||
last_observed = Column(DateTime(timezone=True), server_default=func.now())
|
||||
example_observations = Column(
|
||||
ARRAY(BigInteger)
|
||||
) # IDs of observations showing this pattern
|
||||
reaction_metadata = Column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
Index("reaction_trigger_idx", "trigger_type"),
|
||||
Index("reaction_type_idx", "reaction_type"),
|
||||
Index("reaction_frequency_idx", "frequency"),
|
||||
)
|
||||
|
||||
|
||||
class ObservationPattern(Base):
|
||||
"""
|
||||
Higher-level patterns detected across multiple observations.
|
||||
"""
|
||||
|
||||
__tablename__ = "observation_pattern"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
pattern_type = Column(Text, nullable=False) # behavioral, cognitive, emotional
|
||||
description = Column(Text, nullable=False)
|
||||
supporting_observations = Column(
|
||||
ARRAY(BigInteger), nullable=False
|
||||
) # Observation IDs
|
||||
exceptions = Column(ARRAY(BigInteger)) # Observations that don't fit
|
||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.7)
|
||||
validity_start = Column(DateTime(timezone=True))
|
||||
validity_end = Column(DateTime(timezone=True))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
pattern_metadata = Column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
Index("obs_pattern_type_idx", "pattern_type"),
|
||||
Index("obs_pattern_confidence_idx", "confidence"),
|
||||
Index("obs_pattern_validity_idx", "validity_start", "validity_end"),
|
||||
)
|
||||
|
||||
|
||||
class BeliefCluster(Base):
|
||||
"""
|
||||
Groups of related beliefs that support or depend on each other.
|
||||
"""
|
||||
|
||||
__tablename__ = "belief_cluster"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
cluster_name = Column(Text, nullable=False)
|
||||
core_beliefs = Column(ARRAY(Text), nullable=False)
|
||||
peripheral_beliefs = Column(ARRAY(Text))
|
||||
internal_consistency = Column(Numeric(3, 2)) # How well beliefs align
|
||||
supporting_observations = Column(ARRAY(BigInteger)) # Observation IDs
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
last_updated = Column(DateTime(timezone=True), server_default=func.now())
|
||||
cluster_metadata = Column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
Index("belief_cluster_name_idx", "cluster_name"),
|
||||
Index("belief_cluster_consistency_idx", "internal_consistency"),
|
||||
)
|
||||
|
||||
|
||||
class ConversationMetrics(Base):
|
||||
"""
|
||||
Tracks the effectiveness and depth of conversations.
|
||||
"""
|
||||
|
||||
__tablename__ = "conversation_metrics"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
session_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
depth_score = Column(Numeric(3, 2)) # How deep the conversation went
|
||||
breakthrough_count = Column(Integer, default=0)
|
||||
challenge_acceptance = Column(Numeric(3, 2)) # How well challenges were received
|
||||
new_insights = Column(Integer, default=0)
|
||||
user_engagement = Column(Numeric(3, 2)) # Inferred engagement level
|
||||
duration_minutes = Column(Integer)
|
||||
observation_count = Column(Integer, default=0)
|
||||
contradiction_count = Column(Integer, default=0)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
metrics_metadata = Column(JSONB)
|
||||
|
||||
__table_args__ = (
|
||||
Index("conv_metrics_session_idx", "session_id", unique=True),
|
||||
Index("conv_metrics_depth_idx", "depth_score"),
|
||||
Index("conv_metrics_breakthrough_idx", "breakthrough_count"),
|
||||
)
|
src/memory/common/db/models/source_item.py (new file, 294 lines)
@ -0,0 +1,294 @@
|
||||
"""
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any, Sequence, cast
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
UUID,
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
event,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from memory.common import settings
|
||||
import memory.common.extract as extract
|
||||
import memory.common.collections as collections
|
||||
import memory.common.chunker as chunker
|
||||
import memory.common.summarizer as summarizer
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def handle_duplicate_sha256(session, flush_context, instances):
|
||||
"""
|
||||
Event listener that efficiently checks for duplicate sha256 values before flush
|
||||
and removes items with duplicate sha256 from the session.
|
||||
|
||||
Uses a single query to identify all duplicates rather than querying for each item.
|
||||
"""
|
||||
# Find all SourceItem objects being added
|
||||
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
items = {}
|
||||
for item in new_items:
|
||||
try:
|
||||
if (sha256 := item.sha256) is None:
|
||||
continue
|
||||
|
||||
if sha256 in items:
|
||||
session.expunge(item)
|
||||
continue
|
||||
|
||||
items[sha256] = item
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
# Query database for existing items with these sha256 values in a single query
|
||||
existing_sha256s = set(
|
||||
row[0]
|
||||
for row in session.query(SourceItem.sha256).filter(
|
||||
SourceItem.sha256.in_(items.keys())
|
||||
)
|
||||
)
|
||||
|
||||
# Remove objects with duplicate sha256 values from the session
|
||||
for sha256 in existing_sha256s:
|
||||
if sha256 in items:
|
||||
session.expunge(items[sha256])
|
||||
|
||||
|
||||
def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
|
||||
for i, image in enumerate(images):
|
||||
if not image.filename: # type: ignore
|
||||
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
|
||||
image.save(filename)
|
||||
image.filename = str(filename) # type: ignore
|
||||
|
||||
return [image.filename for image in images] # type: ignore
|
||||
|
||||
|
||||
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
|
||||
return [chunk] + [
|
||||
i
|
||||
for i in images
|
||||
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
|
||||
]
|
||||
|
||||
|
||||
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
final = {}
|
||||
for m in metadata:
|
||||
data = m.copy()
|
||||
if tags := set(data.pop("tags", [])):
|
||||
final["tags"] = tags | final.get("tags", set())
|
||||
final |= data
|
||||
return final
|
||||
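A quick sketch of merge_metadata (values invented for illustration): "tags" from every input are unioned into a set, while other keys merge left to right with later dicts winning.

merge_metadata({"tags": ["ai"], "source_id": 1}, {"tags": ["ml"], "size": 42})
# -> {"tags": {"ai", "ml"}, "source_id": 1, "size": 42}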
|
||||
|
||||
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
|
||||
|
||||
summary, tags = summarizer.summarize(content)
|
||||
full_text: extract.DataChunk = extract.DataChunk(
|
||||
data=[content.strip(), *images], metadata={"tags": tags}
|
||||
)
|
||||
|
||||
chunks: list[extract.DataChunk] = [full_text]
|
||||
tokens = chunker.approx_token_count(content)
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
|
||||
for c in chunker.chunk_text(content)
|
||||
]
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
return [c for c in chunks if c.data]
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
__tablename__ = "chunk"
|
||||
|
||||
# The ID is also used as the vector ID in the vector database
|
||||
id = Column(
|
||||
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
|
||||
)
|
||||
source_id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_paths = Column(
|
||||
ARRAY(Text), nullable=True
|
||||
) # Path to content if stored as a file
|
||||
content = Column(Text) # Direct content storage
|
||||
embedding_model = Column(Text)
|
||||
collection_name = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector: list[float] = []
|
||||
item_metadata: dict[str, Any] = {}
|
||||
images: list[Image.Image] = []
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
Index("chunk_source_idx", "source_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[extract.MulitmodalChunk]:
|
||||
chunks: list[extract.MulitmodalChunk] = []
|
||||
if cast(str | None, self.content):
|
||||
chunks = [cast(str, self.content)]
|
||||
if self.images:
|
||||
chunks += self.images
|
||||
elif cast(Sequence[str] | None, self.file_paths):
|
||||
chunks += [
|
||||
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
|
||||
]
|
||||
return chunks
|
||||
|
||||
@property
|
||||
def data(self) -> list[bytes | str | Image.Image]:
|
||||
if self.file_paths is None:
|
||||
return [cast(str, self.content)]
|
||||
|
||||
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
|
||||
files = [path for path in paths if path.exists()]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
if file_path.exists():
|
||||
items.append(Image.open(file_path))
|
||||
elif file_path.suffix == ".bin":
|
||||
items.append(file_path.read_bytes())
|
||||
else:
|
||||
items.append(file_path.read_text())
|
||||
return items
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||
inserted_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), default=func.now()
|
||||
)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
size = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
# Content is stored in the database if it's small enough and text
|
||||
content = Column(Text)
|
||||
# Otherwise the content is stored on disk
|
||||
filename = Column(Text, nullable=True)
|
||||
|
||||
# Chunks relationship
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
Index("source_modality_idx", "modality"),
|
||||
Index("source_status_idx", "embed_status"),
|
||||
Index("source_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("source_filename_idx", "filename"),
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_ids(self):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
chunks: list[extract.DataChunk] = []
|
||||
content = cast(str | None, self.content)
|
||||
if content:
|
||||
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
|
||||
|
||||
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
summary, tags = summarizer.summarize(content)
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
mime_type = cast(str | None, self.mime_type)
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
|
||||
return chunks
|
||||
|
||||
def _make_chunk(
|
||||
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
|
||||
) -> Chunk:
|
||||
chunk_id = str(uuid.uuid4())
|
||||
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
|
||||
images = [c for c in data.data if isinstance(c, Image.Image)]
|
||||
image_names = image_filenames(chunk_id, images)
|
||||
|
||||
modality = data.modality or cast(str, self.modality)
|
||||
chunk = Chunk(
|
||||
id=chunk_id,
|
||||
source=self,
|
||||
content=text or None,
|
||||
images=images,
|
||||
file_paths=image_names,
|
||||
collection_name=modality,
|
||||
embedding_model=collections.collection_model(modality, text, images),
|
||||
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
|
||||
)
|
||||
return chunk
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
return [self._make_chunk(data) for data in self._chunk_contents()]
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
@property
|
||||
def display_contents(self) -> str | dict | None:
|
||||
return {
|
||||
"tags": self.tags,
|
||||
"size": self.size,
|
||||
}
|
@ -3,18 +3,14 @@ Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, Sequence, cast
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
UUID,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
@ -22,273 +18,24 @@ from sqlalchemy import (
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
event,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common import settings
|
||||
import memory.common.extract as extract
|
||||
import memory.common.collections as collections
|
||||
import memory.common.chunker as chunker
|
||||
import memory.common.summarizer as summarizer
|
||||
import memory.common.formatters.observation as observation
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@event.listens_for(Session, "before_flush")
|
||||
def handle_duplicate_sha256(session, flush_context, instances):
|
||||
"""
|
||||
Event listener that efficiently checks for duplicate sha256 values before flush
|
||||
and removes items with duplicate sha256 from the session.
|
||||
|
||||
Uses a single query to identify all duplicates rather than querying for each item.
|
||||
"""
|
||||
# Find all SourceItem objects being added
|
||||
new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
items = {}
|
||||
for item in new_items:
|
||||
try:
|
||||
if (sha256 := item.sha256) is None:
|
||||
continue
|
||||
|
||||
if sha256 in items:
|
||||
session.expunge(item)
|
||||
continue
|
||||
|
||||
items[sha256] = item
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
if not new_items:
|
||||
return
|
||||
|
||||
# Query database for existing items with these sha256 values in a single query
|
||||
existing_sha256s = set(
|
||||
row[0]
|
||||
for row in session.query(SourceItem.sha256).filter(
|
||||
SourceItem.sha256.in_(items.keys())
|
||||
)
|
||||
)
|
||||
|
||||
# Remove objects with duplicate sha256 values from the session
|
||||
for sha256 in existing_sha256s:
|
||||
if sha256 in items:
|
||||
session.expunge(items[sha256])
|
||||
|
||||
|
||||
def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
|
||||
for i, image in enumerate(images):
|
||||
if not image.filename: # type: ignore
|
||||
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
|
||||
image.save(filename)
|
||||
image.filename = str(filename) # type: ignore
|
||||
|
||||
return [image.filename for image in images] # type: ignore
|
||||
|
||||
|
||||
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
|
||||
return [chunk] + [
|
||||
i
|
||||
for i in images
|
||||
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
|
||||
]
|
||||
|
||||
|
||||
def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
final = {}
|
||||
for m in metadata:
|
||||
if tags := set(m.pop("tags", [])):
|
||||
final["tags"] = tags | final.get("tags", set())
|
||||
final |= m
|
||||
return final
|
||||
|
||||
|
||||
def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
|
||||
if not content.strip():
|
||||
return []
|
||||
|
||||
images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]
|
||||
|
||||
summary, tags = summarizer.summarize(content)
|
||||
full_text: extract.DataChunk = extract.DataChunk(
|
||||
data=[content.strip(), *images], metadata={"tags": tags}
|
||||
)
|
||||
|
||||
chunks: list[extract.DataChunk] = [full_text]
|
||||
tokens = chunker.approx_token_count(content)
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
|
||||
for c in chunker.chunk_text(content)
|
||||
]
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
return [c for c in chunks if c.data]
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
__tablename__ = "chunk"
|
||||
|
||||
# The ID is also used as the vector ID in the vector database
|
||||
id = Column(
|
||||
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
|
||||
)
|
||||
source_id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_paths = Column(
|
||||
ARRAY(Text), nullable=True
|
||||
) # Path to content if stored as a file
|
||||
content = Column(Text) # Direct content storage
|
||||
embedding_model = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector: list[float] = []
|
||||
item_metadata: dict[str, Any] = {}
|
||||
images: list[Image.Image] = []
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
Index("chunk_source_idx", "source_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[extract.MulitmodalChunk]:
|
||||
chunks: list[extract.MulitmodalChunk] = []
|
||||
if cast(str | None, self.content):
|
||||
chunks = [cast(str, self.content)]
|
||||
if self.images:
|
||||
chunks += self.images
|
||||
elif cast(Sequence[str] | None, self.file_paths):
|
||||
chunks += [
|
||||
Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
|
||||
]
|
||||
return chunks
|
||||
|
||||
@property
|
||||
def data(self) -> list[bytes | str | Image.Image]:
|
||||
if self.file_paths is None:
|
||||
return [cast(str, self.content)]
|
||||
|
||||
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
|
||||
files = [path for path in paths if path.exists()]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
if file_path.exists():
|
||||
items.append(Image.open(file_path))
|
||||
elif file_path.suffix == ".bin":
|
||||
items.append(file_path.read_bytes())
|
||||
else:
|
||||
items.append(file_path.read_text())
|
||||
return items
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
size = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
# Content is stored in the database if it's small enough and text
|
||||
content = Column(Text)
|
||||
# Otherwise the content is stored on disk
|
||||
filename = Column(Text, nullable=True)
|
||||
|
||||
# Chunks relationship
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
Index("source_modality_idx", "modality"),
|
||||
Index("source_status_idx", "embed_status"),
|
||||
Index("source_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("source_filename_idx", "filename"),
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_ids(self):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||
chunks: list[extract.DataChunk] = []
|
||||
content = cast(str | None, self.content)
|
||||
if content:
|
||||
chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]
|
||||
|
||||
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
summary, tags = summarizer.summarize(content)
|
||||
chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))
|
||||
|
||||
mime_type = cast(str | None, self.mime_type)
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
|
||||
return chunks
|
||||
|
||||
def _make_chunk(
|
||||
self, data: extract.DataChunk, metadata: dict[str, Any] = {}
|
||||
) -> Chunk:
|
||||
chunk_id = str(uuid.uuid4())
|
||||
text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
|
||||
images = [c for c in data.data if isinstance(c, Image.Image)]
|
||||
image_names = image_filenames(chunk_id, images)
|
||||
|
||||
chunk = Chunk(
|
||||
id=chunk_id,
|
||||
source=self,
|
||||
content=text or None,
|
||||
images=images,
|
||||
file_paths=image_names,
|
||||
embedding_model=collections.collection_model(cast(str, self.modality)),
|
||||
item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
|
||||
)
|
||||
return chunk
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
return [self._make_chunk(data) for data in self._chunk_contents()]
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
@property
|
||||
def display_contents(self) -> str | None:
|
||||
return cast(str | None, self.content) or cast(str | None, self.filename)
|
||||
from memory.common.db.models.source_item import (
|
||||
SourceItem,
|
||||
Chunk,
|
||||
clean_filename,
|
||||
merge_metadata,
|
||||
chunk_mixed,
|
||||
)
|
||||
|
||||
|
||||
class MailMessage(SourceItem):
|
||||
@ -528,51 +275,6 @@ class Comic(SourceItem):
|
||||
return [extract.DataChunk(data=[image, description])]
|
||||
|
||||
|
||||
class Book(Base):
|
||||
"""Book-level metadata table"""
|
||||
|
||||
__tablename__ = "book"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
isbn = Column(Text, unique=True)
|
||||
title = Column(Text, nullable=False)
|
||||
author = Column(Text)
|
||||
publisher = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
language = Column(Text)
|
||||
edition = Column(Text)
|
||||
series = Column(Text)
|
||||
series_number = Column(Integer)
|
||||
total_pages = Column(Integer)
|
||||
file_path = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# Metadata from ebook parser
|
||||
book_metadata = Column(JSONB, name="metadata")
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("book_isbn_idx", "isbn"),
|
||||
Index("book_author_idx", "author"),
|
||||
Index("book_title_idx", "title"),
|
||||
)
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
**super().as_payload(),
|
||||
"isbn": self.isbn,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
"publisher": self.publisher,
|
||||
"published": self.published,
|
||||
"language": self.language,
|
||||
"edition": self.edition,
|
||||
"series": self.series,
|
||||
"series_number": self.series_number,
|
||||
} | (cast(dict, self.book_metadata) or {})
|
||||
|
||||
|
||||
class BookSection(SourceItem):
|
||||
"""Individual sections/chapters of books"""
|
||||
|
||||
@ -797,58 +499,163 @@ class GithubItem(SourceItem):
|
||||
)
|
||||
|
||||
|
||||
class ArticleFeed(Base):
|
||||
__tablename__ = "article_feeds"
|
||||
class AgentObservation(SourceItem):
|
||||
"""
|
||||
Records observations made by AI agents about the user.
|
||||
This is the primary data model for the epistemic sparring partner.
|
||||
"""
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
check_interval = Column(
|
||||
Integer, nullable=False, server_default="60", doc="Minutes between checks"
|
||||
__tablename__ = "agent_observation"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
session_id = Column(
|
||||
UUID(as_uuid=True)
|
||||
) # Groups observations from same conversation
|
||||
observation_type = Column(
|
||||
Text, nullable=False
|
||||
) # belief, preference, pattern, contradiction, behavior
|
||||
subject = Column(Text, nullable=False) # What/who the observation is about
|
||||
confidence = Column(Numeric(3, 2), nullable=False, default=0.8) # 0.0-1.0
|
||||
evidence = Column(JSONB) # Supporting context, quotes, etc.
|
||||
agent_model = Column(Text, nullable=False) # Which AI model made this observation
|
||||
|
||||
# Relationships
|
||||
contradictions_as_first = relationship(
|
||||
"ObservationContradiction",
|
||||
foreign_keys="ObservationContradiction.observation_1_id",
|
||||
back_populates="observation_1",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
contradictions_as_second = relationship(
|
||||
"ObservationContradiction",
|
||||
foreign_keys="ObservationContradiction.observation_2_id",
|
||||
back_populates="observation_2",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "agent_observation",
|
||||
}
|
||||
|
||||
__table_args__ = (
|
||||
Index("article_feeds_active_idx", "active", "last_checked_at"),
|
||||
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("agent_obs_session_idx", "session_id"),
|
||||
Index("agent_obs_type_idx", "observation_type"),
|
||||
Index("agent_obs_subject_idx", "subject"),
|
||||
Index("agent_obs_confidence_idx", "confidence"),
|
||||
Index("agent_obs_model_idx", "agent_model"),
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("modality"):
|
||||
kwargs["modality"] = "observation"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = "email_accounts"
|
||||
def as_payload(self) -> dict:
|
||||
payload = {
|
||||
**super().as_payload(),
|
||||
"observation_type": self.observation_type,
|
||||
"subject": self.subject,
|
||||
"confidence": float(cast(Any, self.confidence)),
|
||||
"evidence": self.evidence,
|
||||
"agent_model": self.agent_model,
|
||||
}
|
||||
if self.session_id is not None:
|
||||
payload["session_id"] = str(self.session_id)
|
||||
return payload
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
@property
|
||||
def all_contradictions(self):
|
||||
"""Get all contradictions involving this observation."""
|
||||
return self.contradictions_as_first + self.contradictions_as_second
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
@property
|
||||
def display_contents(self) -> dict:
|
||||
return {
|
||||
"subject": self.subject,
|
||||
"content": self.content,
|
||||
"observation_type": self.observation_type,
|
||||
"evidence": self.evidence,
|
||||
"confidence": self.confidence,
|
||||
"agent_model": self.agent_model,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
"""
|
||||
Generate multiple chunks for different embedding dimensions.
|
||||
Each chunk goes to a different Qdrant collection for specialized search.
|
||||
"""
|
||||
# 1. Semantic chunk - standard content representation
|
||||
semantic_text = observation.generate_semantic_text(
|
||||
cast(str, self.subject),
|
||||
cast(str, self.observation_type),
|
||||
cast(str, self.content),
|
||||
cast(observation.Evidence, self.evidence),
|
||||
)
|
||||
semantic_chunk = extract.DataChunk(
|
||||
data=[semantic_text],
|
||||
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
|
||||
modality="semantic",
|
||||
)
|
||||
|
||||
# 2. Temporal chunk - time-aware representation
|
||||
temporal_text = observation.generate_temporal_text(
|
||||
cast(str, self.subject),
|
||||
cast(str, self.content),
|
||||
cast(float, self.confidence),
|
||||
cast(datetime, self.inserted_at),
|
||||
)
|
||||
temporal_chunk = extract.DataChunk(
|
||||
data=[temporal_text],
|
||||
metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
|
||||
modality="temporal",
|
||||
)
|
||||
|
||||
others = [
|
||||
self._make_chunk(
|
||||
extract.DataChunk(
|
||||
data=[i],
|
||||
metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
|
||||
modality="semantic",
|
||||
)
|
||||
)
|
||||
for i in [
|
||||
self.content,
|
||||
self.subject,
|
||||
self.observation_type,
|
||||
self.evidence.get("quote", ""),
|
||||
]
|
||||
if i
|
||||
]
|
||||
|
||||
# TODO: Add more embedding dimensions here:
|
||||
# 3. Epistemic chunk - belief structure focused
|
||||
# epistemic_text = self._generate_epistemic_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[epistemic_text],
|
||||
# metadata={**base_metadata, "embedding_type": "epistemic"},
|
||||
# collection_name="observations_epistemic"
|
||||
# ))
|
||||
#
|
||||
# 4. Emotional chunk - emotional context focused
|
||||
# emotional_text = self._generate_emotional_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[emotional_text],
|
||||
# metadata={**base_metadata, "embedding_type": "emotional"},
|
||||
# collection_name="observations_emotional"
|
||||
# ))
|
||||
#
|
||||
# 5. Relational chunk - connection patterns focused
|
||||
# relational_text = self._generate_relational_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[relational_text],
|
||||
# metadata={**base_metadata, "embedding_type": "relational"},
|
||||
# collection_name="observations_relational"
|
||||
# ))
|
||||
|
||||
return [
|
||||
self._make_chunk(semantic_chunk),
|
||||
self._make_chunk(temporal_chunk),
|
||||
] + others
|
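Rough sketch of what data_chunks() now emits for an observation (field values are hypothetical, not taken from the commit):

obs = AgentObservation(
    content="User prefers functional programming",
    subject="programming style",
    observation_type="preference",
    confidence=0.9,
    evidence={"quote": "I try to avoid mutable state"},
    agent_model="claude",
    sha256=b"0" * 32,
)
chunks = obs.data_chunks()  # after the row is flushed, so inserted_at is populated
# -> one Chunk built from generate_semantic_text(...) routed to the "semantic" collection,
#    one Chunk built from generate_temporal_text(...) routed to the "temporal" collection,
#    plus one extra "semantic" Chunk per non-empty field (content, subject, type, quote).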
src/memory/common/db/models/sources.py (new file, 122 lines)
@ -0,0 +1,122 @@
|
||||
"""
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from memory.common.db.models.base import Base
|
||||
|
||||
|
||||
class Book(Base):
|
||||
"""Book-level metadata table"""
|
||||
|
||||
__tablename__ = "book"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
isbn = Column(Text, unique=True)
|
||||
title = Column(Text, nullable=False)
|
||||
author = Column(Text)
|
||||
publisher = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
language = Column(Text)
|
||||
edition = Column(Text)
|
||||
series = Column(Text)
|
||||
series_number = Column(Integer)
|
||||
total_pages = Column(Integer)
|
||||
file_path = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
|
||||
# Metadata from ebook parser
|
||||
book_metadata = Column(JSONB, name="metadata")
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("book_isbn_idx", "isbn"),
|
||||
Index("book_author_idx", "author"),
|
||||
Index("book_title_idx", "title"),
|
||||
)
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
**super().as_payload(),
|
||||
"isbn": self.isbn,
|
||||
"title": self.title,
|
||||
"author": self.author,
|
||||
"publisher": self.publisher,
|
||||
"published": self.published,
|
||||
"language": self.language,
|
||||
"edition": self.edition,
|
||||
"series": self.series,
|
||||
"series_number": self.series_number,
|
||||
} | (cast(dict, self.book_metadata) or {})
|
||||
|
||||
|
||||
class ArticleFeed(Base):
|
||||
__tablename__ = "article_feeds"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
check_interval = Column(
|
||||
Integer, nullable=False, server_default="60", doc="Minutes between checks"
|
||||
)
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("article_feeds_active_idx", "active", "last_checked_at"),
|
||||
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = "email_accounts"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def embed_chunks(
|
||||
chunks: list[str] | list[list[extract.MulitmodalChunk]],
|
||||
chunks: list[list[extract.MulitmodalChunk]],
|
||||
model: str = settings.TEXT_EMBEDDING_MODEL,
|
||||
input_type: Literal["document", "query"] = "document",
|
||||
) -> list[Vector]:
|
||||
@ -24,52 +24,50 @@ def embed_chunks(
|
||||
vo = voyageai.Client() # type: ignore
|
||||
if model == settings.MIXED_EMBEDDING_MODEL:
|
||||
return vo.multimodal_embed(
|
||||
chunks, # type: ignore
|
||||
chunks,
|
||||
model=model,
|
||||
input_type=input_type,
|
||||
).embeddings
|
||||
return vo.embed(chunks, model=model, input_type=input_type).embeddings # type: ignore
|
||||
|
||||
texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
|
||||
return cast(
|
||||
list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
|
||||
)
|
||||
|
||||
|
||||
def break_chunk(
|
||||
chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
|
||||
) -> list[extract.MulitmodalChunk]:
|
||||
result = []
|
||||
for c in chunk:
|
||||
if isinstance(c, str):
|
||||
result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
|
||||
else:
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
chunks: list[list[extract.MulitmodalChunk]],
|
||||
model: str = settings.TEXT_EMBEDDING_MODEL,
|
||||
input_type: Literal["document", "query"] = "document",
|
||||
chunk_size: int = DEFAULT_CHUNK_TOKENS,
|
||||
) -> list[Vector]:
|
||||
chunks = [
|
||||
c
|
||||
for text in texts
|
||||
if isinstance(text, str)
|
||||
for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
|
||||
if c.strip()
|
||||
]
|
||||
if not chunks:
|
||||
chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
|
||||
if not any(chunked_chunks):
|
||||
return []
|
||||
|
||||
try:
|
||||
return embed_chunks(chunks, model, input_type)
|
||||
except voyageai.error.InvalidRequestError as e: # type: ignore
|
||||
logger.error(f"Error embedding text: {e}")
|
||||
logger.debug(f"Text: {texts}")
|
||||
raise
|
||||
return embed_chunks(chunked_chunks, model, input_type)
|
||||
|
||||
|
||||
def embed_mixed(
|
||||
items: list[extract.MulitmodalChunk],
|
||||
items: list[list[extract.MulitmodalChunk]],
|
||||
model: str = settings.MIXED_EMBEDDING_MODEL,
|
||||
input_type: Literal["document", "query"] = "document",
|
||||
chunk_size: int = DEFAULT_CHUNK_TOKENS,
|
||||
) -> list[Vector]:
|
||||
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
|
||||
if isinstance(item, str):
|
||||
return [
|
||||
c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
|
||||
]
|
||||
return [item]
|
||||
|
||||
chunks = [c for item in items for c in to_chunks(item)]
|
||||
return embed_chunks([chunks], model, input_type)
|
||||
chunked_chunks = [break_chunk(item, chunk_size) for item in items]
|
||||
return embed_chunks(chunked_chunks, model, input_type)
|
||||
|
||||
|
||||
def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
|
||||
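Minimal sketch of the new break_chunk helper (assumes DEFAULT_CHUNK_TOKENS and OVERLAP_TOKENS from memory.common.chunker): oversized strings inside a multimodal chunk are re-chunked with overlap before embedding.

long_text = "word " * 10_000
pieces = break_chunk([long_text], chunk_size=512)
# pieces is a list of overlapping text windows produced by chunk_text(long_text, 512, OVERLAP_TOKENS)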
@ -87,6 +85,9 @@ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
|
||||
|
||||
def embed_source_item(item: SourceItem) -> list[Chunk]:
|
||||
chunks = list(item.data_chunks())
|
||||
logger.error(
|
||||
f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
|
||||
)
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
|
@ -21,6 +21,7 @@ class DataChunk:
|
||||
data: Sequence[MulitmodalChunk]
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
mime_type: str = "text/plain"
|
||||
modality: str | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
src/memory/common/formatters/observation.py (new file, 87 lines)
@ -0,0 +1,87 @@
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class Evidence(TypedDict):
|
||||
quote: str
|
||||
context: str
|
||||
|
||||
|
||||
def generate_semantic_text(
|
||||
subject: str, observation_type: str, content: str, evidence: Evidence | None = None
|
||||
) -> str:
|
||||
"""Generate text optimized for semantic similarity search."""
|
||||
parts = [
|
||||
f"Subject: {subject}",
|
||||
f"Type: {observation_type}",
|
||||
f"Observation: {content}",
|
||||
]
|
||||
|
||||
if not evidence or not isinstance(evidence, dict):
|
||||
return " | ".join(parts)
|
||||
|
||||
if "quote" in evidence:
|
||||
parts.append(f"Quote: {evidence['quote']}")
|
||||
if "context" in evidence:
|
||||
parts.append(f"Context: {evidence['context']}")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
def generate_temporal_text(
|
||||
subject: str,
|
||||
content: str,
|
||||
confidence: float,
|
||||
created_at: datetime,
|
||||
) -> str:
|
||||
"""Generate text with temporal context for time-pattern search."""
|
||||
# Add temporal markers
|
||||
time_of_day = created_at.strftime("%H:%M")
|
||||
day_of_week = created_at.strftime("%A")
|
||||
|
||||
# Categorize time periods
|
||||
hour = created_at.hour
|
||||
if 5 <= hour < 12:
|
||||
time_period = "morning"
|
||||
elif 12 <= hour < 17:
|
||||
time_period = "afternoon"
|
||||
elif 17 <= hour < 22:
|
||||
time_period = "evening"
|
||||
else:
|
||||
time_period = "late_night"
|
||||
|
||||
parts = [
|
||||
f"Time: {time_of_day} on {day_of_week} ({time_period})",
|
||||
f"Subject: {subject}",
|
||||
f"Observation: {content}",
|
||||
]
|
||||
if confidence:
|
||||
parts.append(f"Confidence: {confidence}")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
# TODO: Add more embedding dimensions here:
|
||||
# 3. Epistemic chunk - belief structure focused
|
||||
# epistemic_text = self._generate_epistemic_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[epistemic_text],
|
||||
# metadata={**base_metadata, "embedding_type": "epistemic"},
|
||||
# collection_name="observations_epistemic"
|
||||
# ))
|
||||
#
|
||||
# 4. Emotional chunk - emotional context focused
|
||||
# emotional_text = self._generate_emotional_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[emotional_text],
|
||||
# metadata={**base_metadata, "embedding_type": "emotional"},
|
||||
# collection_name="observations_emotional"
|
||||
# ))
|
||||
#
|
||||
# 5. Relational chunk - connection patterns focused
|
||||
# relational_text = self._generate_relational_text()
|
||||
# chunks.append(extract.DataChunk(
|
||||
# data=[relational_text],
|
||||
# metadata={**base_metadata, "embedding_type": "relational"},
|
||||
# collection_name="observations_relational"
|
||||
# ))
|
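A hedged example of the two formatter helpers (values invented for illustration):

from datetime import datetime
from memory.common.formatters import observation

observation.generate_semantic_text(
    subject="programming style",
    observation_type="preference",
    content="Prefers functional code",
    evidence={"quote": "I avoid mutable state", "context": "code review chat"},
)
# -> "Subject: programming style | Type: preference | Observation: Prefers functional code | Quote: I avoid mutable state | Context: code review chat"

observation.generate_temporal_text(
    subject="programming style",
    content="Prefers functional code",
    confidence=0.8,
    created_at=datetime(2025, 5, 31, 15, 49),
)
# -> "Time: 15:49 on Saturday (afternoon) | Subject: programming style | Observation: Prefers functional code | Confidence: 0.8"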
@ -9,6 +9,8 @@ logger = logging.getLogger(__name__)
|
||||
TAGS_PROMPT = """
|
||||
The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
|
||||
|
||||
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
|
||||
|
||||
Return your response as JSON with this format:
|
||||
{{
|
||||
"summary": "{summary}",
|
||||
@ -23,6 +25,8 @@ SUMMARY_PROMPT = """
|
||||
Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
|
||||
Also provide 3-5 relevant tags that capture the main topics or themes.
|
||||
|
||||
Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
|
||||
|
||||
Return your response as JSON with this format:
|
||||
{{
|
||||
"summary": "your summary here",
|
||||
|
@ -1,62 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Any
|
||||
from fastapi import UploadFile
|
||||
import httpx
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
SERVER = "http://localhost:8000"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
mcp = FastMCP("memory")
|
||||
|
||||
|
||||
async def make_request(
|
||||
path: str,
|
||||
method: str,
|
||||
data: dict | None = None,
|
||||
json: dict | None = None,
|
||||
files: list[UploadFile] | None = None,
|
||||
) -> httpx.Response:
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await client.request(
|
||||
method,
|
||||
f"{SERVER}/{path}",
|
||||
data=data,
|
||||
json=json,
|
||||
files=files, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
async def post_data(path: str, data: dict | None = None) -> httpx.Response:
|
||||
return await make_request(path, "POST", data=data)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search(
|
||||
query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
|
||||
) -> list[dict[str, Any]]:
|
||||
logger.error(f"Searching for {query}")
|
||||
resp = await post_data(
|
||||
"search",
|
||||
{
|
||||
"query": query,
|
||||
"previews": previews,
|
||||
"modalities": modalities,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize and run the server
|
||||
args = argparse.ArgumentParser()
|
||||
args.add_argument("--server", type=str)
|
||||
args = args.parse_args()
|
||||
|
||||
SERVER = args.server
|
||||
|
||||
mcp.run(transport=args.transport)
|
@ -6,13 +6,14 @@ the complete workflow: existence checking, content hashing, embedding generation
|
||||
vector storage, and result tracking.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
import hashlib
|
||||
import traceback
|
||||
import logging
|
||||
from typing import Any, Callable, Iterable, Sequence, cast
|
||||
|
||||
from memory.common import embedding, qdrant
|
||||
from memory.common.db.models import SourceItem
|
||||
from memory.common.db.models import SourceItem, Chunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -103,7 +104,19 @@ def embed_source_item(source_item: SourceItem) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
|
||||
def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]:
|
||||
collections: dict[str, dict[str, list[Any]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
for chunk in chunks:
|
||||
collection = collections[cast(str, chunk.collection_name)]
|
||||
collection["ids"].append(chunk.id)
|
||||
collection["vectors"].append(chunk.vector)
|
||||
collection["payloads"].append(chunk.item_metadata)
|
||||
return collections
|
||||
|
||||
|
||||
def push_to_qdrant(source_items: Sequence[SourceItem]):
|
||||
"""
|
||||
Push embeddings to Qdrant vector database.
|
||||
|
||||
@ -135,17 +148,16 @@ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
|
||||
return
|
||||
|
||||
try:
|
||||
vector_ids = [str(chunk.id) for chunk in all_chunks]
|
||||
vectors = [chunk.vector for chunk in all_chunks]
|
||||
payloads = [chunk.item_metadata for chunk in all_chunks]
|
||||
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant.get_qdrant_client(),
|
||||
collection_name=collection_name,
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
)
|
||||
client = qdrant.get_qdrant_client()
|
||||
collections = by_collection(all_chunks)
|
||||
for collection_name, collection in collections.items():
|
||||
qdrant.upsert_vectors(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
ids=collection["ids"],
|
||||
vectors=collection["vectors"],
|
||||
payloads=collection["payloads"],
|
||||
)
|
||||
|
||||
for item in items_to_process:
|
||||
item.embed_status = "STORED" # type: ignore
|
||||
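Sketch of the grouping step (shape only, values illustrative): chunks from one batch can now target several collections, so they are bucketed by collection_name before the upsert loop.

grouped = by_collection(all_chunks)
# grouped looks roughly like:
# {
#     "semantic": {"ids": [...], "vectors": [...], "payloads": [...]},
#     "temporal": {"ids": [...], "vectors": [...], "payloads": [...]},
# }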
@@ -222,7 +234,7 @@ def process_content_item(item: SourceItem, session) -> dict[str, Any]:
        return create_task_result(item, status, content_length=getattr(item, "size", 0))

    try:
        push_to_qdrant([item], cast(str, item.modality))
        push_to_qdrant([item])
        status = "processed"
        item.embed_status = "STORED"  # type: ignore
        logger.info(
@@ -185,7 +185,7 @@ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
            f"Embedded section: {section.section_title} - {section.content[:100]}"
        )
    logger.info("Pushing to Qdrant")
    push_to_qdrant(all_sections, "book")
    push_to_qdrant(all_sections)
    logger.info("Committing session")

    session.commit()
@@ -1,4 +1,3 @@
from memory.common.db.models import SourceItem
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from typing import cast
@@ -6,17 +5,20 @@ import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings, chunker, extract
from memory.common.db.models import (
from memory.common.db.models.sources import Book
from memory.common.db.models.source_items import (
    Chunk,
    clean_filename,
    image_filenames,
    add_pics,
    MailMessage,
    EmailAttachment,
    BookSection,
    BlogPost,
    Book,
)
from memory.common.db.models.source_item import (
    SourceItem,
    image_filenames,
    add_pics,
    merge_metadata,
    clean_filename,
)

@ -585,288 +587,6 @@ def test_chunk_constraint_validation(
|
||||
assert chunk.id is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"modality,expected_modality",
|
||||
[
|
||||
(None, "email"), # Default case
|
||||
("custom", "custom"), # Override case
|
||||
],
|
||||
)
|
||||
def test_mail_message_modality(modality, expected_modality):
|
||||
"""Test MailMessage modality setting"""
|
||||
kwargs = {"sha256": b"test", "content": "test"}
|
||||
if modality is not None:
|
||||
kwargs["modality"] = modality
|
||||
|
||||
mail_message = MailMessage(**kwargs)
|
||||
# The __init__ method should set the correct modality
|
||||
assert hasattr(mail_message, "modality")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sender,folder,expected_path",
|
||||
[
|
||||
("user@example.com", "INBOX", "user_example_com/INBOX"),
|
||||
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
|
||||
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
|
||||
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
|
||||
],
|
||||
)
|
||||
def test_mail_message_attachments_path(sender, folder, expected_path):
|
||||
"""Test MailMessage.attachments_path property"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content="test", sender=sender, folder=folder
|
||||
)
|
||||
|
||||
result = mail_message.attachments_path
|
||||
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filename,expected",
|
||||
[
|
||||
("document.pdf", "document.pdf"),
|
||||
("file with spaces.txt", "file_with_spaces.txt"),
|
||||
("file@#$%^&*().doc", "file.doc"),
|
||||
("no-extension", "no_extension"),
|
||||
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
|
||||
],
|
||||
)
|
||||
def test_mail_message_safe_filename(tmp_path, filename, expected):
|
||||
"""Test MailMessage.safe_filename method"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
|
||||
)
|
||||
|
||||
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
|
||||
assert mail_message.safe_filename(filename) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sent_at,expected_date",
|
||||
[
|
||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_mail_message_as_payload(sent_at, expected_date):
|
||||
"""Test MailMessage.as_payload method"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test",
|
||||
content="test",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient1@example.com", "recipient2@example.com"],
|
||||
folder="INBOX",
|
||||
sent_at=sent_at,
|
||||
tags=["tag1", "tag2"],
|
||||
size=1024,
|
||||
)
|
||||
# Manually set id for testing
|
||||
object.__setattr__(mail_message, "id", 123)
|
||||
|
||||
payload = mail_message.as_payload()
|
||||
|
||||
expected = {
|
||||
"source_id": 123,
|
||||
"size": 1024,
|
||||
"message_id": "<test@example.com>",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": ["recipient1@example.com", "recipient2@example.com"],
|
||||
"folder": "INBOX",
|
||||
"tags": [
|
||||
"tag1",
|
||||
"tag2",
|
||||
"sender@example.com",
|
||||
"recipient1@example.com",
|
||||
"recipient2@example.com",
|
||||
],
|
||||
"date": expected_date,
|
||||
}
|
||||
assert payload == expected
|
||||
|
||||
|
||||
def test_mail_message_parsed_content():
|
||||
"""Test MailMessage.parsed_content property with actual email parsing"""
|
||||
# Use a simple email format that the parser can handle
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
result = mail_message.parsed_content
|
||||
|
||||
# Just test that it returns a dict-like object
|
||||
assert isinstance(result, dict)
|
||||
assert "body" in result
|
||||
|
||||
|
||||
def test_mail_message_body_property():
|
||||
"""Test MailMessage.body property with actual email parsing"""
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
assert mail_message.body == "Test Body Content"
|
||||
|
||||
|
||||
def test_mail_message_display_contents():
|
||||
"""Test MailMessage.display_contents property with actual email parsing"""
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
expected = (
|
||||
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
|
||||
)
|
||||
assert mail_message.display_contents == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"created_at,expected_date",
|
||||
[
|
||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_email_attachment_as_payload(created_at, expected_date):
|
||||
"""Test EmailAttachment.as_payload method"""
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
filename="document.pdf",
|
||||
mime_type="application/pdf",
|
||||
size=1024,
|
||||
mail_message_id=123,
|
||||
created_at=created_at,
|
||||
tags=["pdf", "document"],
|
||||
)
|
||||
# Manually set id for testing
|
||||
object.__setattr__(attachment, "id", 456)
|
||||
|
||||
payload = attachment.as_payload()
|
||||
|
||||
expected = {
|
||||
"source_id": 456,
|
||||
"filename": "document.pdf",
|
||||
"content_type": "application/pdf",
|
||||
"size": 1024,
|
||||
"created_at": expected_date,
|
||||
"mail_message_id": 123,
|
||||
"tags": ["pdf", "document"],
|
||||
}
|
||||
assert payload == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_filename,content_source,expected_content",
|
||||
[
|
||||
(True, "file", b"test file content"),
|
||||
(False, "content", "attachment content"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.extract.extract_data_chunks")
|
||||
def test_email_attachment_data_chunks(
|
||||
mock_extract, has_filename, content_source, expected_content, tmp_path
|
||||
):
|
||||
"""Test EmailAttachment.data_chunks method"""
|
||||
from memory.common.extract import DataChunk
|
||||
|
||||
mock_extract.return_value = [
|
||||
DataChunk(data=["extracted text"], metadata={"source": content_source})
|
||||
]
|
||||
|
||||
if has_filename:
|
||||
# Create a test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_bytes(b"test file content")
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
filename=str(test_file),
|
||||
mime_type="text/plain",
|
||||
mail_message_id=123,
|
||||
)
|
||||
else:
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
content="attachment content",
|
||||
filename=None,
|
||||
mime_type="text/plain",
|
||||
mail_message_id=123,
|
||||
)
|
||||
|
||||
# Mock _make_chunk to return a simple chunk
|
||||
mock_chunk = Mock()
|
||||
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
|
||||
result = attachment.data_chunks({"extra": "metadata"})
|
||||
|
||||
# Verify the method calls
|
||||
mock_extract.assert_called_once_with("text/plain", expected_content)
|
||||
mock_make.assert_called_once_with(
|
||||
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
|
||||
{"extra": "metadata"},
|
||||
)
|
||||
assert result == [mock_chunk]
|
||||
|
||||
|
||||
def test_email_attachment_cascade_delete(db_session: Session):
|
||||
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test_email",
|
||||
content="test email",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test_attachment",
|
||||
content="attachment content",
|
||||
mail_message=mail_message,
|
||||
filename="test.txt",
|
||||
mime_type="text/plain",
|
||||
size=100,
|
||||
modality="attachment", # Set modality explicitly
|
||||
)
|
||||
db_session.add(attachment)
|
||||
db_session.commit()
|
||||
|
||||
attachment_id = attachment.id
|
||||
|
||||
# Delete the mail message
|
||||
db_session.delete(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the attachment was also deleted
|
||||
deleted_attachment = (
|
||||
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
|
||||
)
|
||||
assert deleted_attachment is None
|
||||
|
||||
|
||||
def test_subclass_deletion_cascades_to_source_item(db_session: Session):
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test_email_cascade",
|
||||
@ -926,173 +646,3 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
|
||||
# Verify both the MailMessage and SourceItem records are deleted
|
||||
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pages,expected_chunks",
|
||||
[
|
||||
# No pages
|
||||
([], []),
|
||||
# Single page
|
||||
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
|
||||
# Multiple pages
|
||||
(
|
||||
["Page 1", "Page 2", "Page 3"],
|
||||
[
|
||||
(
|
||||
"Page 1\n\nPage 2\n\nPage 3",
|
||||
{"type": "section", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
|
||||
],
|
||||
),
|
||||
# Empty/whitespace pages filtered out
|
||||
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
|
||||
# All empty - no chunks created
|
||||
(["", " ", " "], []),
|
||||
],
|
||||
)
|
||||
def test_book_section_data_chunks(pages, expected_chunks):
|
||||
"""Test BookSection.data_chunks with various page combinations"""
|
||||
content = "\n\n".join(pages).strip()
|
||||
book_section = BookSection(
|
||||
sha256=b"test_section",
|
||||
content=content,
|
||||
modality="book",
|
||||
book_id=1,
|
||||
start_page=10,
|
||||
end_page=10 + len(pages),
|
||||
pages=pages,
|
||||
book=Book(id=1, title="Test Book", author="Test Author"),
|
||||
)
|
||||
|
||||
chunks = book_section.data_chunks()
|
||||
expected = [
|
||||
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
|
||||
]
|
||||
assert [(c.content, c.item_metadata) for c in chunks] == expected
|
||||
for c in chunks:
|
||||
assert cast(list, c.file_paths) == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content,expected",
|
||||
[
|
||||
("", []),
|
||||
(
|
||||
"Short content",
|
||||
[
|
||||
extract.DataChunk(
|
||||
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
|
||||
[
|
||||
extract.DataChunk(
|
||||
data=[
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
||||
],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["This is a very long piece of content that"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["should be chunked into multiple pieces when"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["processed."],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["test"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
|
||||
default_chunk_size(10)
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content=content,
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[],
|
||||
)
|
||||
|
||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||
assert blog_post._chunk_contents() == expected
|
||||
|
||||
|
||||
def test_blog_post_chunk_contents_with_images(tmp_path):
|
||||
"""Test BlogPost._chunk_contents with images"""
|
||||
# Create test image files
|
||||
img1_path = tmp_path / "img1.jpg"
|
||||
img2_path = tmp_path / "img2.jpg"
|
||||
for img_path in [img1_path, img2_path]:
|
||||
img = Image.new("RGB", (10, 10), color="red")
|
||||
img.save(img_path)
|
||||
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content="Content with images",
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[str(img1_path), str(img2_path)],
|
||||
)
|
||||
|
||||
result = blog_post._chunk_contents()
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
assert result == [
|
||||
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
|
||||
]
|
||||
|
||||
|
||||
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
|
||||
default_chunk_size(10)
|
||||
img1_path = tmp_path / "img1.jpg"
|
||||
img2_path = tmp_path / "img2.jpg"
|
||||
for img_path in [img1_path, img2_path]:
|
||||
img = Image.new("RGB", (10, 10), color="red")
|
||||
img.save(img_path)
|
||||
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[str(img1_path), str(img2_path)],
|
||||
)
|
||||
|
||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||
result = blog_post._chunk_contents()
|
||||
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
assert result == [
|
||||
[
|
||||
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||
img1_path.as_posix(),
|
||||
img2_path.as_posix(),
|
||||
],
|
||||
[
|
||||
f"First picture is here: {img1_path.as_posix()}",
|
||||
img1_path.as_posix(),
|
||||
],
|
||||
[
|
||||
f"Second picture is here: {img2_path.as_posix()}",
|
||||
img2_path.as_posix(),
|
||||
],
|
||||
["test"],
|
||||
]
|
tests/memory/common/db/models/test_source_items.py (new file, 607 lines)
@@ -0,0 +1,607 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from unittest.mock import patch, Mock
|
||||
from typing import cast
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from memory.common import settings, chunker, extract
|
||||
from memory.common.db.models.sources import Book
|
||||
from memory.common.db.models.source_items import (
|
||||
MailMessage,
|
||||
EmailAttachment,
|
||||
BookSection,
|
||||
BlogPost,
|
||||
AgentObservation,
|
||||
)
|
||||
from memory.common.db.models.source_item import merge_metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_chunk_size():
|
||||
chunk_length = chunker.DEFAULT_CHUNK_TOKENS
|
||||
real_chunker = chunker.chunk_text
|
||||
|
||||
def chunk_text(text: str, max_tokens: int = 0):
|
||||
max_tokens = max_tokens or chunk_length
|
||||
return real_chunker(text, max_tokens=max_tokens)
|
||||
|
||||
def set_size(new_size: int):
|
||||
nonlocal chunk_length
|
||||
chunk_length = new_size
|
||||
|
||||
with patch.object(chunker, "chunk_text", chunk_text):
|
||||
yield set_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"modality,expected_modality",
|
||||
[
|
||||
(None, "email"), # Default case
|
||||
("custom", "custom"), # Override case
|
||||
],
|
||||
)
|
||||
def test_mail_message_modality(modality, expected_modality):
|
||||
"""Test MailMessage modality setting"""
|
||||
kwargs = {"sha256": b"test", "content": "test"}
|
||||
if modality is not None:
|
||||
kwargs["modality"] = modality
|
||||
|
||||
mail_message = MailMessage(**kwargs)
|
||||
# The __init__ method should set the correct modality
|
||||
assert hasattr(mail_message, "modality")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sender,folder,expected_path",
|
||||
[
|
||||
("user@example.com", "INBOX", "user_example_com/INBOX"),
|
||||
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
|
||||
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
|
||||
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
|
||||
],
|
||||
)
|
||||
def test_mail_message_attachments_path(sender, folder, expected_path):
|
||||
"""Test MailMessage.attachments_path property"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content="test", sender=sender, folder=folder
|
||||
)
|
||||
|
||||
result = mail_message.attachments_path
|
||||
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filename,expected",
|
||||
[
|
||||
("document.pdf", "document.pdf"),
|
||||
("file with spaces.txt", "file_with_spaces.txt"),
|
||||
("file@#$%^&*().doc", "file.doc"),
|
||||
("no-extension", "no_extension"),
|
||||
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
|
||||
],
|
||||
)
|
||||
def test_mail_message_safe_filename(tmp_path, filename, expected):
|
||||
"""Test MailMessage.safe_filename method"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
|
||||
)
|
||||
|
||||
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
|
||||
assert mail_message.safe_filename(filename) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sent_at,expected_date",
|
||||
[
|
||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_mail_message_as_payload(sent_at, expected_date):
|
||||
"""Test MailMessage.as_payload method"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test",
|
||||
content="test",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test Subject",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient1@example.com", "recipient2@example.com"],
|
||||
folder="INBOX",
|
||||
sent_at=sent_at,
|
||||
tags=["tag1", "tag2"],
|
||||
size=1024,
|
||||
)
|
||||
# Manually set id for testing
|
||||
object.__setattr__(mail_message, "id", 123)
|
||||
|
||||
payload = mail_message.as_payload()
|
||||
|
||||
expected = {
|
||||
"source_id": 123,
|
||||
"size": 1024,
|
||||
"message_id": "<test@example.com>",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": ["recipient1@example.com", "recipient2@example.com"],
|
||||
"folder": "INBOX",
|
||||
"tags": [
|
||||
"tag1",
|
||||
"tag2",
|
||||
"sender@example.com",
|
||||
"recipient1@example.com",
|
||||
"recipient2@example.com",
|
||||
],
|
||||
"date": expected_date,
|
||||
}
|
||||
assert payload == expected
|
||||
|
||||
|
||||
def test_mail_message_parsed_content():
|
||||
"""Test MailMessage.parsed_content property with actual email parsing"""
|
||||
# Use a simple email format that the parser can handle
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
result = mail_message.parsed_content
|
||||
|
||||
# Just test that it returns a dict-like object
|
||||
assert isinstance(result, dict)
|
||||
assert "body" in result
|
||||
|
||||
|
||||
def test_mail_message_body_property():
|
||||
"""Test MailMessage.body property with actual email parsing"""
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
assert mail_message.body == "Test Body Content"
|
||||
|
||||
|
||||
def test_mail_message_display_contents():
|
||||
"""Test MailMessage.display_contents property with actual email parsing"""
|
||||
email_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
|
||||
Test Body Content"""
|
||||
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||
)
|
||||
|
||||
expected = (
|
||||
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
|
||||
)
|
||||
assert mail_message.display_contents == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"created_at,expected_date",
|
||||
[
|
||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_email_attachment_as_payload(created_at, expected_date):
|
||||
"""Test EmailAttachment.as_payload method"""
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
filename="document.pdf",
|
||||
mime_type="application/pdf",
|
||||
size=1024,
|
||||
mail_message_id=123,
|
||||
created_at=created_at,
|
||||
tags=["pdf", "document"],
|
||||
)
|
||||
# Manually set id for testing
|
||||
object.__setattr__(attachment, "id", 456)
|
||||
|
||||
payload = attachment.as_payload()
|
||||
|
||||
expected = {
|
||||
"source_id": 456,
|
||||
"filename": "document.pdf",
|
||||
"content_type": "application/pdf",
|
||||
"size": 1024,
|
||||
"created_at": expected_date,
|
||||
"mail_message_id": 123,
|
||||
"tags": ["pdf", "document"],
|
||||
}
|
||||
assert payload == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_filename,content_source,expected_content",
|
||||
[
|
||||
(True, "file", b"test file content"),
|
||||
(False, "content", "attachment content"),
|
||||
],
|
||||
)
|
||||
@patch("memory.common.extract.extract_data_chunks")
|
||||
def test_email_attachment_data_chunks(
|
||||
mock_extract, has_filename, content_source, expected_content, tmp_path
|
||||
):
|
||||
"""Test EmailAttachment.data_chunks method"""
|
||||
from memory.common.extract import DataChunk
|
||||
|
||||
mock_extract.return_value = [
|
||||
DataChunk(data=["extracted text"], metadata={"source": content_source})
|
||||
]
|
||||
|
||||
if has_filename:
|
||||
# Create a test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_bytes(b"test file content")
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
filename=str(test_file),
|
||||
mime_type="text/plain",
|
||||
mail_message_id=123,
|
||||
)
|
||||
else:
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test",
|
||||
content="attachment content",
|
||||
filename=None,
|
||||
mime_type="text/plain",
|
||||
mail_message_id=123,
|
||||
)
|
||||
|
||||
# Mock _make_chunk to return a simple chunk
|
||||
mock_chunk = Mock()
|
||||
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
|
||||
result = attachment.data_chunks({"extra": "metadata"})
|
||||
|
||||
# Verify the method calls
|
||||
mock_extract.assert_called_once_with("text/plain", expected_content)
|
||||
mock_make.assert_called_once_with(
|
||||
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
|
||||
{"extra": "metadata"},
|
||||
)
|
||||
assert result == [mock_chunk]
|
||||
|
||||
|
||||
def test_email_attachment_cascade_delete(db_session: Session):
|
||||
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
|
||||
mail_message = MailMessage(
|
||||
sha256=b"test_email",
|
||||
content="test email",
|
||||
message_id="<test@example.com>",
|
||||
subject="Test",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
attachment = EmailAttachment(
|
||||
sha256=b"test_attachment",
|
||||
content="attachment content",
|
||||
mail_message=mail_message,
|
||||
filename="test.txt",
|
||||
mime_type="text/plain",
|
||||
size=100,
|
||||
modality="attachment", # Set modality explicitly
|
||||
)
|
||||
db_session.add(attachment)
|
||||
db_session.commit()
|
||||
|
||||
attachment_id = attachment.id
|
||||
|
||||
# Delete the mail message
|
||||
db_session.delete(mail_message)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the attachment was also deleted
|
||||
deleted_attachment = (
|
||||
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
|
||||
)
|
||||
assert deleted_attachment is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pages,expected_chunks",
|
||||
[
|
||||
# No pages
|
||||
([], []),
|
||||
# Single page
|
||||
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
|
||||
# Multiple pages
|
||||
(
|
||||
["Page 1", "Page 2", "Page 3"],
|
||||
[
|
||||
(
|
||||
"Page 1\n\nPage 2\n\nPage 3",
|
||||
{"type": "section", "tags": {"tag1", "tag2"}},
|
||||
),
|
||||
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
|
||||
],
|
||||
),
|
||||
# Empty/whitespace pages filtered out
|
||||
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
|
||||
# All empty - no chunks created
|
||||
(["", " ", " "], []),
|
||||
],
|
||||
)
|
||||
def test_book_section_data_chunks(pages, expected_chunks):
|
||||
"""Test BookSection.data_chunks with various page combinations"""
|
||||
content = "\n\n".join(pages).strip()
|
||||
book_section = BookSection(
|
||||
sha256=b"test_section",
|
||||
content=content,
|
||||
modality="book",
|
||||
book_id=1,
|
||||
start_page=10,
|
||||
end_page=10 + len(pages),
|
||||
pages=pages,
|
||||
book=Book(id=1, title="Test Book", author="Test Author"),
|
||||
)
|
||||
|
||||
chunks = book_section.data_chunks()
|
||||
expected = [
|
||||
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
|
||||
]
|
||||
assert [(c.content, c.item_metadata) for c in chunks] == expected
|
||||
for c in chunks:
|
||||
assert cast(list, c.file_paths) == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content,expected",
|
||||
[
|
||||
("", []),
|
||||
(
|
||||
"Short content",
|
||||
[
|
||||
extract.DataChunk(
|
||||
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
|
||||
[
|
||||
extract.DataChunk(
|
||||
data=[
|
||||
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
||||
],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["This is a very long piece of content that"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["should be chunked into multiple pieces when"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["processed."],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
extract.DataChunk(
|
||||
data=["test"],
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
|
||||
default_chunk_size(10)
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content=content,
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[],
|
||||
)
|
||||
|
||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||
assert blog_post._chunk_contents() == expected
|
||||
|
||||
|
||||
def test_blog_post_chunk_contents_with_images(tmp_path):
|
||||
"""Test BlogPost._chunk_contents with images"""
|
||||
# Create test image files
|
||||
img1_path = tmp_path / "img1.jpg"
|
||||
img2_path = tmp_path / "img2.jpg"
|
||||
for img_path in [img1_path, img2_path]:
|
||||
img = Image.new("RGB", (10, 10), color="red")
|
||||
img.save(img_path)
|
||||
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content="Content with images",
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[str(img1_path), str(img2_path)],
|
||||
)
|
||||
|
||||
result = blog_post._chunk_contents()
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
assert result == [
|
||||
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
|
||||
]
|
||||
|
||||
|
||||
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
|
||||
default_chunk_size(10)
|
||||
img1_path = tmp_path / "img1.jpg"
|
||||
img2_path = tmp_path / "img2.jpg"
|
||||
for img_path in [img1_path, img2_path]:
|
||||
img = Image.new("RGB", (10, 10), color="red")
|
||||
img.save(img_path)
|
||||
|
||||
blog_post = BlogPost(
|
||||
sha256=b"test_blog",
|
||||
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||
modality="blog",
|
||||
url="https://example.com/post",
|
||||
images=[str(img1_path), str(img2_path)],
|
||||
)
|
||||
|
||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||
result = blog_post._chunk_contents()
|
||||
|
||||
result = [
|
||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||
for c in result
|
||||
]
|
||||
assert result == [
|
||||
[
|
||||
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||
img1_path.as_posix(),
|
||||
img2_path.as_posix(),
|
||||
],
|
||||
[
|
||||
f"First picture is here: {img1_path.as_posix()}",
|
||||
img1_path.as_posix(),
|
||||
],
|
||||
[
|
||||
f"Second picture is here: {img2_path.as_posix()}",
|
||||
img2_path.as_posix(),
|
||||
],
|
||||
["test"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metadata,expected_semantic_metadata,expected_temporal_metadata,observation_tags",
|
||||
[
|
||||
(
|
||||
{},
|
||||
{"embedding_type": "semantic"},
|
||||
{"embedding_type": "temporal"},
|
||||
[],
|
||||
),
|
||||
(
|
||||
{"extra_key": "extra_value"},
|
||||
{"extra_key": "extra_value", "embedding_type": "semantic"},
|
||||
{"extra_key": "extra_value", "embedding_type": "temporal"},
|
||||
[],
|
||||
),
|
||||
(
|
||||
{"tags": ["existing_tag"], "source": "test"},
|
||||
{"tags": {"existing_tag"}, "source": "test", "embedding_type": "semantic"},
|
||||
{"tags": {"existing_tag"}, "source": "test", "embedding_type": "temporal"},
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_agent_observation_data_chunks(
|
||||
metadata, expected_semantic_metadata, expected_temporal_metadata, observation_tags
|
||||
):
|
||||
"""Test AgentObservation.data_chunks generates correct chunks with proper metadata"""
|
||||
observation = AgentObservation(
|
||||
sha256=b"test_obs",
|
||||
content="User prefers Python over JavaScript",
|
||||
subject="programming preferences",
|
||||
observation_type="preference",
|
||||
confidence=0.9,
|
||||
evidence={
|
||||
"quote": "I really like Python",
|
||||
"context": "discussion about languages",
|
||||
},
|
||||
agent_model="claude-3.5-sonnet",
|
||||
session_id=uuid.uuid4(),
|
||||
tags=observation_tags,
|
||||
)
|
||||
# Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
|
||||
object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))
|
||||
|
||||
result = observation.data_chunks(metadata)
|
||||
|
||||
# Verify chunks
|
||||
assert len(result) == 2
|
||||
|
||||
semantic_chunk = result[0]
|
||||
expected_semantic_text = "Subject: programming preferences | Type: preference | Observation: User prefers Python over JavaScript | Quote: I really like Python | Context: discussion about languages"
|
||||
assert semantic_chunk.data == [expected_semantic_text]
|
||||
assert semantic_chunk.metadata == expected_semantic_metadata
|
||||
assert semantic_chunk.collection_name == "semantic"
|
||||
|
||||
temporal_chunk = result[1]
|
||||
expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
|
||||
assert temporal_chunk.data == [expected_temporal_text]
|
||||
assert temporal_chunk.metadata == expected_temporal_metadata
|
||||
assert temporal_chunk.collection_name == "temporal"
|
||||
|
||||
|
||||
def test_agent_observation_data_chunks_with_none_values():
|
||||
"""Test AgentObservation.data_chunks handles None values correctly"""
|
||||
observation = AgentObservation(
|
||||
sha256=b"test_obs",
|
||||
content="Content",
|
||||
subject="subject",
|
||||
observation_type="belief",
|
||||
confidence=0.7,
|
||||
evidence=None,
|
||||
agent_model="gpt-4",
|
||||
session_id=None,
|
||||
)
|
||||
object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))
|
||||
|
||||
result = observation.data_chunks()
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].collection_name == "semantic"
|
||||
assert result[1].collection_name == "temporal"
|
||||
|
||||
# Verify content with None evidence
|
||||
semantic_text = "Subject: subject | Type: belief | Observation: Content"
|
||||
assert result[0].data == [semantic_text]
|
||||
|
||||
temporal_text = "Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
|
||||
assert result[1].data == [temporal_text]
|
||||
|
||||
|
||||
def test_agent_observation_data_chunks_merge_metadata_behavior():
|
||||
"""Test that merge_metadata works correctly in data_chunks"""
|
||||
observation = AgentObservation(
|
||||
sha256=b"test",
|
||||
content="test",
|
||||
subject="test",
|
||||
observation_type="test",
|
||||
confidence=0.8,
|
||||
evidence={},
|
||||
agent_model="test",
|
||||
tags=["base_tag"], # Set base tags so they appear in both chunks
|
||||
)
|
||||
object.__setattr__(observation, "inserted_at", datetime.now())
|
||||
|
||||
# Test that metadata merging preserves original values and adds new ones
|
||||
input_metadata = {"existing": "value", "tags": ["tag1"]}
|
||||
result = observation.data_chunks(input_metadata)
|
||||
|
||||
semantic_metadata = result[0].metadata
|
||||
temporal_metadata = result[1].metadata
|
||||
|
||||
# Both should have the existing metadata plus embedding_type
|
||||
assert semantic_metadata["existing"] == "value"
|
||||
assert semantic_metadata["tags"] == {"tag1"} # Merged tags
|
||||
assert semantic_metadata["embedding_type"] == "semantic"
|
||||
|
||||
assert temporal_metadata["existing"] == "value"
|
||||
assert temporal_metadata["tags"] == {"tag1"} # Merged tags
|
||||
assert temporal_metadata["embedding_type"] == "temporal"
|
tests/memory/common/formatters/test_observation.py (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from memory.common.formatters.observation import (
|
||||
Evidence,
|
||||
generate_semantic_text,
|
||||
generate_temporal_text,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_semantic_text_basic_functionality():
|
||||
evidence: Evidence = {"quote": "test quote", "context": "test context"}
|
||||
result = generate_semantic_text(
|
||||
subject="test_subject",
|
||||
observation_type="test_type",
|
||||
content="test_content",
|
||||
evidence=evidence,
|
||||
)
|
||||
assert (
|
||||
result
|
||||
== "Subject: test_subject | Type: test_type | Observation: test_content | Quote: test quote | Context: test context"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"evidence,expected_suffix",
|
||||
[
|
||||
({"quote": "test quote"}, " | Quote: test quote"),
|
||||
({"context": "test context"}, " | Context: test context"),
|
||||
({}, ""),
|
||||
],
|
||||
)
|
||||
def test_generate_semantic_text_partial_evidence(
|
||||
evidence: dict[str, str], expected_suffix: str
|
||||
):
|
||||
result = generate_semantic_text(
|
||||
subject="subject",
|
||||
observation_type="type",
|
||||
content="content",
|
||||
evidence=evidence, # type: ignore
|
||||
)
|
||||
expected = f"Subject: subject | Type: type | Observation: content{expected_suffix}"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_generate_semantic_text_none_evidence():
|
||||
result = generate_semantic_text(
|
||||
subject="subject",
|
||||
observation_type="type",
|
||||
content="content",
|
||||
evidence=None, # type: ignore
|
||||
)
|
||||
assert result == "Subject: subject | Type: type | Observation: content"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_evidence",
|
||||
[
|
||||
"string",
|
||||
123,
|
||||
["list"],
|
||||
True,
|
||||
],
|
||||
)
|
||||
def test_generate_semantic_text_invalid_evidence_types(invalid_evidence: Any):
|
||||
result = generate_semantic_text(
|
||||
subject="subject",
|
||||
observation_type="type",
|
||||
content="content",
|
||||
evidence=invalid_evidence, # type: ignore
|
||||
)
|
||||
assert result == "Subject: subject | Type: type | Observation: content"
|
||||
|
||||
|
||||
def test_generate_semantic_text_empty_strings():
|
||||
evidence = {"quote": "", "context": ""}
|
||||
result = generate_semantic_text(
|
||||
subject="",
|
||||
observation_type="",
|
||||
content="",
|
||||
evidence=evidence, # type: ignore
|
||||
)
|
||||
assert result == "Subject: | Type: | Observation: | Quote: | Context: "
|
||||
|
||||
|
||||
def test_generate_semantic_text_special_characters():
|
||||
evidence: Evidence = {
|
||||
"quote": "Quote with | pipe and | symbols",
|
||||
"context": "Context with special chars: @#$%",
|
||||
}
|
||||
result = generate_semantic_text(
|
||||
subject="Subject with | pipe",
|
||||
observation_type="Type with | pipe",
|
||||
content="Content with | pipe",
|
||||
evidence=evidence,
|
||||
)
|
||||
expected = "Subject: Subject with | pipe | Type: Type with | pipe | Observation: Content with | pipe | Quote: Quote with | pipe and | symbols | Context: Context with special chars: @#$%"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hour,expected_period",
|
||||
[
|
||||
(5, "morning"),
|
||||
(6, "morning"),
|
||||
(11, "morning"),
|
||||
(12, "afternoon"),
|
||||
(13, "afternoon"),
|
||||
(16, "afternoon"),
|
||||
(17, "evening"),
|
||||
(18, "evening"),
|
||||
(21, "evening"),
|
||||
(22, "late_night"),
|
||||
(23, "late_night"),
|
||||
(0, "late_night"),
|
||||
(1, "late_night"),
|
||||
(4, "late_night"),
|
||||
],
|
||||
)
|
||||
def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
|
||||
test_date = datetime(2024, 1, 15, hour, 30) # Monday
|
||||
result = generate_temporal_text(
|
||||
subject="test_subject",
|
||||
content="test_content",
|
||||
confidence=0.8,
|
||||
created_at=test_date,
|
||||
)
|
||||
time_str = test_date.strftime("%H:%M")
|
||||
expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content | Confidence: 0.8"
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weekday,day_name",
|
||||
[
|
||||
(0, "Monday"),
|
||||
(1, "Tuesday"),
|
||||
(2, "Wednesday"),
|
||||
(3, "Thursday"),
|
||||
(4, "Friday"),
|
||||
(5, "Saturday"),
|
||||
(6, "Sunday"),
|
||||
],
|
||||
)
|
||||
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
|
||||
test_date = datetime(2024, 1, 15 + weekday, 10, 30)
|
||||
result = generate_temporal_text(
|
||||
subject="subject", content="content", confidence=0.5, created_at=test_date
|
||||
)
|
||||
assert f"on {day_name}" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("confidence", [0.0, 0.1, 0.5, 0.99, 1.0])
|
||||
def test_generate_temporal_text_confidence_values(confidence: float):
|
||||
test_date = datetime(2024, 1, 15, 10, 30)
|
||||
result = generate_temporal_text(
|
||||
subject="subject",
|
||||
content="content",
|
||||
confidence=confidence,
|
||||
created_at=test_date,
|
||||
)
|
||||
assert f"Confidence: {confidence}" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_date,expected_period",
|
||||
[
|
||||
(datetime(2024, 1, 15, 5, 0), "morning"), # Start of morning
|
||||
(datetime(2024, 1, 15, 11, 59), "morning"), # End of morning
|
||||
(datetime(2024, 1, 15, 12, 0), "afternoon"), # Start of afternoon
|
||||
(datetime(2024, 1, 15, 16, 59), "afternoon"), # End of afternoon
|
||||
(datetime(2024, 1, 15, 17, 0), "evening"), # Start of evening
|
||||
(datetime(2024, 1, 15, 21, 59), "evening"), # End of evening
|
||||
(datetime(2024, 1, 15, 22, 0), "late_night"), # Start of late_night
|
||||
(datetime(2024, 1, 15, 4, 59), "late_night"), # End of late_night
|
||||
],
|
||||
)
|
||||
def test_generate_temporal_text_boundary_cases(
|
||||
test_date: datetime, expected_period: str
|
||||
):
|
||||
result = generate_temporal_text(
|
||||
subject="subject", content="content", confidence=0.8, created_at=test_date
|
||||
)
|
||||
assert f"({expected_period})" in result
|
||||
|
||||
|
||||
def test_generate_temporal_text_complete_format():
|
||||
test_date = datetime(2024, 3, 22, 14, 45) # Friday afternoon
|
||||
result = generate_temporal_text(
|
||||
subject="Important observation",
|
||||
content="User showed strong preference for X",
|
||||
confidence=0.95,
|
||||
created_at=test_date,
|
||||
)
|
||||
expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X | Confidence: 0.95"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_generate_temporal_text_empty_strings():
|
||||
test_date = datetime(2024, 1, 15, 10, 30)
|
||||
result = generate_temporal_text(
|
||||
subject="", content="", confidence=0.0, created_at=test_date
|
||||
)
|
||||
assert (
|
||||
result
|
||||
== "Time: 10:30 on Monday (morning) | Subject: | Observation: | Confidence: 0.0"
|
||||
)
|
||||
|
||||
|
||||
def test_generate_temporal_text_special_characters():
|
||||
test_date = datetime(2024, 1, 15, 15, 20)
|
||||
result = generate_temporal_text(
|
||||
subject="Subject with | pipe",
|
||||
content="Content with | pipe and @#$ symbols",
|
||||
confidence=0.75,
|
||||
created_at=test_date,
|
||||
)
|
||||
expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols | Confidence: 0.75"
|
||||
assert result == expected
|
@@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import (
    process_content_item,
    push_to_qdrant,
    safe_task_execution,
    by_collection,
)

@@ -71,6 +72,7 @@ def mock_chunk():
    chunk.id = "00000000-0000-0000-0000-000000000001"
    chunk.vector = [0.1] * 1024
    chunk.item_metadata = {"source_id": 1, "tags": ["test"]}
    chunk.collection_name = "mail"
    return chunk

@@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant):
    mock_chunk1.id = "00000000-0000-0000-0000-000000000001"
    mock_chunk1.vector = [0.1] * 1024
    mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]}
    mock_chunk1.collection_name = "mail"

    mock_chunk2 = MagicMock()
    mock_chunk2.id = "00000000-0000-0000-0000-000000000002"
    mock_chunk2.vector = [0.2] * 1024
    mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]}
    mock_chunk2.collection_name = "mail"

    # Assign chunks directly (bypassing SQLAlchemy relationship)
    item1.chunks = [mock_chunk1]
    item2.chunks = [mock_chunk2]

    push_to_qdrant([item1, item2], "mail")
    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == "STORED"
    assert str(item2.embed_status) == "STORED"
@@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing(
            mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}"
            mock_chunk.vector = [0.1] * 1024
            mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]}
            mock_chunk.collection_name = "mail"
            item.chunks = [mock_chunk]
        else:
            item.chunks = []
@@ -302,7 +307,7 @@
    item1 = create_item("1", item1_status, item1_has_chunks)
    item2 = create_item("2", item2_status, item2_has_chunks)

    push_to_qdrant([item1, item2], "mail")
    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == expected_item1_status
    assert str(item2.embed_status) == expected_item2_status
@@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk):
        side_effect=Exception("Qdrant error"),
    ):
        with pytest.raises(Exception, match="Qdrant error"):
            push_to_qdrant([sample_mail_message], "mail")
            push_to_qdrant([sample_mail_message])

    assert str(sample_mail_message.embed_status) == "FAILED"

@ -418,6 +423,196 @@ def test_create_task_result_no_title():
|
||||
assert result["chunks_count"] == 0
|
||||
|
||||
|
||||
def test_by_collection_empty_chunks():
|
||||
result = by_collection([])
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_by_collection_single_chunk():
|
||||
chunk = Chunk(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
content="test content",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
item_metadata={"source_id": 1, "tags": ["test"]},
|
||||
collection_name="test_collection",
|
||||
)
|
||||
|
||||
result = by_collection([chunk])
|
||||
|
||||
assert len(result) == 1
|
||||
assert "test_collection" in result
|
||||
assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"]
|
||||
assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]]
|
||||
assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}]
|
||||
|
||||
|
||||
def test_by_collection_multiple_chunks_same_collection():
|
||||
chunks = [
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
content="test content 1",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1, 0.2],
|
||||
item_metadata={"source_id": 1},
|
||||
collection_name="collection_a",
|
||||
),
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000002",
|
||||
content="test content 2",
|
||||
embedding_model="test-model",
|
||||
vector=[0.3, 0.4],
|
||||
item_metadata={"source_id": 2},
|
||||
collection_name="collection_a",
|
||||
),
|
||||
]
|
||||
|
||||
result = by_collection(chunks)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "collection_a" in result
|
||||
assert result["collection_a"]["ids"] == [
|
||||
"00000000-0000-0000-0000-000000000001",
|
||||
"00000000-0000-0000-0000-000000000002",
|
||||
]
|
||||
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]]
|
||||
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}]
|
||||
|
||||
|
||||
def test_by_collection_multiple_chunks_different_collections():
|
||||
chunks = [
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
content="test content 1",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1, 0.2],
|
||||
item_metadata={"source_id": 1},
|
||||
collection_name="collection_a",
|
||||
),
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000002",
|
||||
content="test content 2",
|
||||
embedding_model="test-model",
|
||||
vector=[0.3, 0.4],
|
||||
item_metadata={"source_id": 2},
|
||||
collection_name="collection_b",
|
||||
),
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000003",
|
||||
content="test content 3",
|
||||
embedding_model="test-model",
|
||||
vector=[0.5, 0.6],
|
||||
item_metadata={"source_id": 3},
|
||||
collection_name="collection_a",
|
||||
),
|
||||
]
|
||||
|
||||
result = by_collection(chunks)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "collection_a" in result
|
||||
assert "collection_b" in result
|
||||
|
||||
# Check collection_a
|
||||
assert result["collection_a"]["ids"] == [
|
||||
"00000000-0000-0000-0000-000000000001",
|
||||
"00000000-0000-0000-0000-000000000003",
|
||||
]
|
||||
assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]]
|
||||
assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}]
|
||||
|
||||
# Check collection_b
|
||||
assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"]
|
||||
assert result["collection_b"]["vectors"] == [[0.3, 0.4]]
|
||||
assert result["collection_b"]["payloads"] == [{"source_id": 2}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"collection_names,expected_collections",
|
||||
[
|
||||
(["col1", "col1", "col1"], 1),
|
||||
(["col1", "col2", "col3"], 3),
|
||||
(["col1", "col2", "col1", "col2"], 2),
|
||||
(["single"], 1),
|
||||
],
|
||||
)
|
||||
def test_by_collection_various_groupings(collection_names, expected_collections):
|
||||
chunks = [
|
||||
Chunk(
|
||||
id=f"00000000-0000-0000-0000-00000000000{i}",
|
||||
content=f"test content {i}",
|
||||
embedding_model="test-model",
|
||||
vector=[float(i)],
|
||||
item_metadata={"index": i},
|
||||
collection_name=collection_name,
|
||||
)
|
||||
for i, collection_name in enumerate(collection_names, 1)
|
||||
]
|
||||
|
||||
result = by_collection(chunks)
|
||||
|
||||
assert len(result) == expected_collections
|
||||
# Verify all chunks are accounted for
|
||||
total_chunks = sum(len(coll["ids"]) for coll in result.values())
|
||||
assert total_chunks == len(chunks)
|
||||
|
||||
|
||||
def test_by_collection_with_none_values():
|
||||
chunks = [
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000001",
|
||||
content="test content",
|
||||
embedding_model="test-model",
|
||||
vector=None, # None vector
|
||||
item_metadata=None, # None metadata
|
||||
collection_name="test_collection",
|
||||
),
|
||||
Chunk(
|
||||
id="00000000-0000-0000-0000-000000000002",
|
||||
content="test content 2",
|
||||
embedding_model="test-model",
|
||||
vector=[0.1, 0.2],
|
||||
item_metadata={"key": "value"},
|
||||
collection_name="test_collection",
|
||||
),
|
||||
]
|
||||
|
||||
result = by_collection(chunks)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "test_collection" in result
|
||||
assert result["test_collection"]["ids"] == [
|
||||
"00000000-0000-0000-0000-000000000001",
|
||||
"00000000-0000-0000-0000-000000000002",
|
||||
]
|
||||
assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]]
|
||||
assert result["test_collection"]["payloads"] == [None, {"key": "value"}]
|
||||
|
||||
|
||||
def test_by_collection_preserves_order():
|
||||
chunks = []
|
||||
for i in range(5):
|
||||
chunks.append(
|
||||
Chunk(
|
||||
id=f"00000000-0000-0000-0000-00000000000{i}",
|
||||
content=f"test content {i}",
|
||||
embedding_model="test-model",
|
||||
vector=[float(i)],
|
||||
item_metadata={"order": i},
|
||||
collection_name="ordered_collection",
|
||||
)
|
||||
)
|
||||
|
||||
result = by_collection(chunks)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result["ordered_collection"]["ids"] == [
|
||||
f"00000000-0000-0000-0000-00000000000{i}" for i in range(5)
|
||||
]
|
||||
assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)]
|
||||
assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"embedding_return,qdrant_error,expected_status,expected_embed_status",
|
||||
[
|
||||
@ -459,6 +654,7 @@ def test_process_content_item(
|
||||
embedding_model="test-model",
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"source_id": 1, "tags": ["test"]},
|
||||
collection_name="mail",
|
||||
)
|
||||
mock_chunks = [real_chunk]
|
||||
else: # empty
|
||||
|
@ -2,6 +2,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch, call
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
|
||||
id=chunk_id,
|
||||
source=item,
|
||||
content=f"Test chunk content {i}",
|
||||
collection_name=item.modality,
|
||||
embedding_model="test-model",
|
||||
)
|
||||
for i, chunk_id in enumerate(chunk_ids)
|
||||
@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
|
||||
db_session.commit()
|
||||
|
||||
# Add vectors to Qdrant
|
||||
modality = "mail" if item_type == "MailMessage" else "blog"
|
||||
modality = cast(str, item.modality)
|
||||
qd.ensure_collection_exists(qdrant, modality, 1024)
|
||||
qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))
|
||||
|
||||
@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content 1",
|
||||
embedding_model="test-model",
|
||||
collection_name=modality,
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content 2",
|
||||
embedding_model="test-model",
|
||||
collection_name=modality,
|
||||
vector=[0.2] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant):
|
||||
id=str(uuid.uuid4()),
|
||||
content="New chunk content",
|
||||
embedding_model="test-model",
|
||||
collection_name=item.modality,
|
||||
vector=[0.1] * 1024,
|
||||
item_metadata={"source_id": item.id, "tags": ["test"]},
|
||||
),
|
||||
@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type):
|
||||
source=item_with_chunks,
|
||||
content="Test chunk content",
|
||||
embedding_model="test-model",
|
||||
collection_name=item_with_chunks.modality,
|
||||
)
|
||||
db_session.add(chunk)
|
||||
db_session.commit()
|
||||
|