Mirror of https://github.com/mruwnik/memory.git (synced 2025-07-30 22:56:08 +02:00)

Compare commits: e505f9b53c ... b10a1fb130 (3 commits)

Commits:
- b10a1fb130
- 1dd93929c1
- 004bd39987
@@ -12,6 +12,32 @@ from alembic import context

 from memory.common import settings
 from memory.common.db.models import Base
 
+# Import all models to ensure they're registered with Base.metadata
+from memory.common.db.models import (
+    SourceItem,
+    Chunk,
+    MailMessage,
+    EmailAttachment,
+    ChatMessage,
+    BlogPost,
+    Comic,
+    BookSection,
+    ForumPost,
+    GithubItem,
+    GitCommit,
+    Photo,
+    MiscDoc,
+    AgentObservation,
+    ObservationContradiction,
+    ReactionPattern,
+    ObservationPattern,
+    BeliefCluster,
+    ConversationMetrics,
+    Book,
+    ArticleFeed,
+    EmailAccount,
+)
+
 
 # this is the Alembic Config object
 config = context.config
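For context: the comment above notes that the models are imported so they get registered on Base.metadata. A minimal sketch of how an Alembic env.py typically consumes that metadata for autogenerate; the rest of this repo's env.py is not part of this diff, so the exact wiring is an assumption:

from memory.common.db.models import Base

# Alembic's autogenerate compares the live database against this metadata,
# so every model must be imported before this line runs.
target_metadata = Base.metadata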
db/migrations/versions/20250531_154947_add_observation_models.py (new file, 279 lines)
@@ -0,0 +1,279 @@
"""Add observation models

Revision ID: 6554eb260176
Revises: 2524646f56f6
Create Date: 2025-05-31 15:49:47.579256

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "6554eb260176"
down_revision: Union[str, None] = "2524646f56f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "belief_cluster",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("cluster_name", sa.Text(), nullable=False),
        sa.Column("core_beliefs", sa.ARRAY(sa.Text()), nullable=False),
        sa.Column("peripheral_beliefs", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column(
            "internal_consistency", sa.Numeric(precision=3, scale=2), nullable=True
        ),
        sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "last_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "cluster_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "belief_cluster_consistency_idx",
        "belief_cluster",
        ["internal_consistency"],
        unique=False,
    )
    op.create_index(
        "belief_cluster_name_idx", "belief_cluster", ["cluster_name"], unique=False
    )
    op.create_table(
        "conversation_metrics",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("session_id", sa.UUID(), nullable=False),
        sa.Column("depth_score", sa.Numeric(precision=3, scale=2), nullable=True),
        sa.Column("breakthrough_count", sa.Integer(), nullable=True),
        sa.Column(
            "challenge_acceptance", sa.Numeric(precision=3, scale=2), nullable=True
        ),
        sa.Column("new_insights", sa.Integer(), nullable=True),
        sa.Column("user_engagement", sa.Numeric(precision=3, scale=2), nullable=True),
        sa.Column("duration_minutes", sa.Integer(), nullable=True),
        sa.Column("observation_count", sa.Integer(), nullable=True),
        sa.Column("contradiction_count", sa.Integer(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "metrics_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "conv_metrics_breakthrough_idx",
        "conversation_metrics",
        ["breakthrough_count"],
        unique=False,
    )
    op.create_index(
        "conv_metrics_depth_idx", "conversation_metrics", ["depth_score"], unique=False
    )
    op.create_index(
        "conv_metrics_session_idx", "conversation_metrics", ["session_id"], unique=True
    )
    op.create_table(
        "observation_pattern",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("pattern_type", sa.Text(), nullable=False),
        sa.Column("description", sa.Text(), nullable=False),
        sa.Column("supporting_observations", sa.ARRAY(sa.BigInteger()), nullable=False),
        sa.Column("exceptions", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column("validity_start", sa.DateTime(timezone=True), nullable=True),
        sa.Column("validity_end", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "pattern_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "obs_pattern_confidence_idx",
        "observation_pattern",
        ["confidence"],
        unique=False,
    )
    op.create_index(
        "obs_pattern_type_idx", "observation_pattern", ["pattern_type"], unique=False
    )
    op.create_index(
        "obs_pattern_validity_idx",
        "observation_pattern",
        ["validity_start", "validity_end"],
        unique=False,
    )
    op.create_table(
        "reaction_pattern",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("trigger_type", sa.Text(), nullable=False),
        sa.Column("reaction_type", sa.Text(), nullable=False),
        sa.Column("frequency", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column(
            "first_observed",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column(
            "last_observed",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column("example_observations", sa.ARRAY(sa.BigInteger()), nullable=True),
        sa.Column(
            "reaction_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "reaction_frequency_idx", "reaction_pattern", ["frequency"], unique=False
    )
    op.create_index(
        "reaction_trigger_idx", "reaction_pattern", ["trigger_type"], unique=False
    )
    op.create_index(
        "reaction_type_idx", "reaction_pattern", ["reaction_type"], unique=False
    )
    op.create_table(
        "agent_observation",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("session_id", sa.UUID(), nullable=True),
        sa.Column("observation_type", sa.Text(), nullable=False),
        sa.Column("subject", sa.Text(), nullable=False),
        sa.Column("confidence", sa.Numeric(precision=3, scale=2), nullable=False),
        sa.Column("evidence", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("agent_model", sa.Text(), nullable=False),
        sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "agent_obs_confidence_idx", "agent_observation", ["confidence"], unique=False
    )
    op.create_index(
        "agent_obs_model_idx", "agent_observation", ["agent_model"], unique=False
    )
    op.create_index(
        "agent_obs_session_idx", "agent_observation", ["session_id"], unique=False
    )
    op.create_index(
        "agent_obs_subject_idx", "agent_observation", ["subject"], unique=False
    )
    op.create_index(
        "agent_obs_type_idx", "agent_observation", ["observation_type"], unique=False
    )
    op.create_table(
        "observation_contradiction",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("observation_1_id", sa.BigInteger(), nullable=False),
        sa.Column("observation_2_id", sa.BigInteger(), nullable=False),
        sa.Column("contradiction_type", sa.Text(), nullable=False),
        sa.Column(
            "detected_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column("detection_method", sa.Text(), nullable=False),
        sa.Column("resolution", sa.Text(), nullable=True),
        sa.Column(
            "observation_metadata",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["observation_1_id"], ["agent_observation.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(
            ["observation_2_id"], ["agent_observation.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "obs_contra_method_idx",
        "observation_contradiction",
        ["detection_method"],
        unique=False,
    )
    op.create_index(
        "obs_contra_obs1_idx",
        "observation_contradiction",
        ["observation_1_id"],
        unique=False,
    )
    op.create_index(
        "obs_contra_obs2_idx",
        "observation_contradiction",
        ["observation_2_id"],
        unique=False,
    )
    op.create_index(
        "obs_contra_type_idx",
        "observation_contradiction",
        ["contradiction_type"],
        unique=False,
    )
    op.add_column("chunk", sa.Column("collection_name", sa.Text(), nullable=True))


def downgrade() -> None:
    op.drop_column("chunk", "collection_name")
    op.drop_index("obs_contra_type_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_obs2_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_obs1_idx", table_name="observation_contradiction")
    op.drop_index("obs_contra_method_idx", table_name="observation_contradiction")
    op.drop_table("observation_contradiction")
    op.drop_index("agent_obs_type_idx", table_name="agent_observation")
    op.drop_index("agent_obs_subject_idx", table_name="agent_observation")
    op.drop_index("agent_obs_session_idx", table_name="agent_observation")
    op.drop_index("agent_obs_model_idx", table_name="agent_observation")
    op.drop_index("agent_obs_confidence_idx", table_name="agent_observation")
    op.drop_table("agent_observation")
    op.drop_index("reaction_type_idx", table_name="reaction_pattern")
    op.drop_index("reaction_trigger_idx", table_name="reaction_pattern")
    op.drop_index("reaction_frequency_idx", table_name="reaction_pattern")
    op.drop_table("reaction_pattern")
    op.drop_index("obs_pattern_validity_idx", table_name="observation_pattern")
    op.drop_index("obs_pattern_type_idx", table_name="observation_pattern")
    op.drop_index("obs_pattern_confidence_idx", table_name="observation_pattern")
    op.drop_table("observation_pattern")
    op.drop_index("conv_metrics_session_idx", table_name="conversation_metrics")
    op.drop_index("conv_metrics_depth_idx", table_name="conversation_metrics")
    op.drop_index("conv_metrics_breakthrough_idx", table_name="conversation_metrics")
    op.drop_table("conversation_metrics")
    op.drop_index("belief_cluster_name_idx", table_name="belief_cluster")
    op.drop_index("belief_cluster_consistency_idx", table_name="belief_cluster")
    op.drop_table("belief_cluster")
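To apply or roll back the revision above, the usual route is Alembic's CLI (alembic upgrade head, alembic downgrade 2524646f56f6) or its Python API. A minimal sketch, assuming the config file lives at alembic.ini (the location is an assumption, not shown in this diff):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")             # assumed config location
command.upgrade(cfg, "6554eb260176")    # apply "Add observation models"
command.downgrade(cfg, "2524646f56f6")  # revert to the previous revision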
@@ -3,3 +3,4 @@ uvicorn==0.29.0
 python-jose==3.3.0
 python-multipart==0.0.9
 sqladmin
+mcp==1.9.2
@@ -5,4 +5,6 @@ alembic==1.13.1
 dotenv==0.9.9
 voyageai==0.3.2
 qdrant-client==1.9.0
 anthropic==0.18.1
+
+bm25s[full]==0.2.13
src/memory/api/MCP/tools.py (new file, 654 lines)
@@ -0,0 +1,654 @@
"""
MCP tools for the epistemic sparring partner system.
"""

from datetime import datetime, timezone
from hashlib import sha256
import logging
import uuid
from typing import cast
from mcp.server.fastmcp import FastMCP
from memory.common.db.models.source_item import SourceItem
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy import func, cast as sql_cast, Text
from memory.common.db.connection import make_session

from memory.common import extract
from memory.common.db.models import AgentObservation
from memory.api.search import search, SearchFilters
from memory.common.formatters import observation
from memory.workers.tasks.content_processing import process_content_item

logger = logging.getLogger(__name__)

# Create MCP server instance
mcp = FastMCP("memory", stateless=True)


@mcp.tool()
async def get_all_tags() -> list[str]:
    """
    Get all unique tags used across the entire knowledge base.

    Purpose:
        This tool retrieves all tags that have been used in the system, both from
        AI observations (created with 'observe') and other content. Use it to
        understand the tag taxonomy, ensure consistency, or discover related topics.

    When to use:
    - Before creating new observations, to use consistent tag naming
    - To explore what topics/contexts have been tracked
    - To build tag filters for search operations
    - To understand the user's areas of interest
    - For tag autocomplete or suggestion features

    Returns:
        Sorted list of all unique tags in the system. Tags follow patterns like:
        - Topics: "machine-learning", "functional-programming"
        - Projects: "project:website-redesign"
        - Contexts: "context:work", "context:late-night"
        - Domains: "domain:finance"

    Example:
        # Get all tags to ensure consistency
        tags = await get_all_tags()
        # Returns: ["ai-safety", "context:work", "functional-programming",
        #           "machine-learning", "project:thesis", ...]

        # Use to check if a topic has been discussed before
        if "quantum-computing" in tags:
            # Search for related observations
            observations = await search_observations(
                query="quantum computing",
                tags=["quantum-computing"]
            )
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})


@mcp.tool()
async def get_all_subjects() -> list[str]:
    """
    Get all unique subjects from observations about the user.

    Purpose:
        This tool retrieves all subject identifiers that have been used in
        observations (created with 'observe'). Subjects are the consistent
        identifiers for what observations are about. Use this to understand
        what aspects of the user have been tracked and ensure consistency.

    When to use:
    - Before creating new observations, to use existing subject names
    - To discover what aspects of the user have been observed
    - To build subject filters for targeted searches
    - To ensure consistent naming across observations
    - To get an overview of the user model

    Returns:
        Sorted list of all unique subjects. Common patterns include:
        - "programming_style", "programming_philosophy"
        - "work_habits", "work_schedule"
        - "ai_beliefs", "ai_safety_beliefs"
        - "learning_preferences"
        - "communication_style"

    Example:
        # Get all subjects to ensure consistency
        subjects = await get_all_subjects()
        # Returns: ["ai_safety_beliefs", "architecture_preferences",
        #           "programming_philosophy", "work_schedule", ...]

        # Use to check what we know about the user
        if "programming_style" in subjects:
            # Get all programming-related observations
            observations = await search_observations(
                query="programming",
                subject="programming_style"
            )

    Best practices:
    - Always check existing subjects before creating new ones
    - Use snake_case for consistency
    - Be specific but not too granular
    - Group related observations under same subject
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )


@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """
    Get all unique observation types that have been used.

    Purpose:
        This tool retrieves the distinct observation types that have been recorded
        in the system. While the standard types are predefined (belief, preference,
        behavior, contradiction, general), this shows what's actually been used.
        Helpful for understanding the distribution of observation types.

    When to use:
    - To see what types of observations have been made
    - To understand the balance of different observation types
    - To check if all standard types are being utilized
    - For analytics or reporting on observation patterns

    Standard types:
    - "belief": Opinions or beliefs the user holds
    - "preference": Things they prefer or favor
    - "behavior": Patterns in how they act or work
    - "contradiction": Noted inconsistencies
    - "general": Observations that don't fit other categories

    Returns:
        List of observation types that have actually been used in the system.

    Example:
        # Check what types of observations exist
        types = await get_all_observation_types()
        # Returns: ["behavior", "belief", "contradiction", "preference"]

        # Use to analyze observation distribution
        for obs_type in types:
            observations = await search_observations(
                query="",
                observation_types=[obs_type],
                limit=100
            )
            print(f"{obs_type}: {len(observations)} observations")
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )


@mcp.tool()
async def search_knowledge_base(
    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict]:
    """
    Search through the user's stored knowledge and content.

    Purpose:
        This tool searches the user's personal knowledge base - a collection of
        their saved content including emails, documents, blog posts, books, and
        more. Use this alongside 'search_observations' to build a complete picture:
        - search_knowledge_base: Finds user's actual content and information
        - search_observations: Finds AI-generated insights about the user
        Together they enable deeply personalized, context-aware assistance.

    When to use:
    - User asks about something they've read/written/received
    - You need to find specific content the user has saved
    - User references a document, email, or article
    - To provide quotes or information from user's sources
    - To understand context from user's past communications
    - When user says "that article about..." or similar references

    How it works:
        Uses hybrid search combining semantic understanding with keyword matching.
        This means it finds content based on meaning AND specific terms, giving
        you the best of both approaches. Results are ranked by relevance.

    Args:
        query: Natural language search query. Be descriptive about what you're
            looking for. The search understands meaning but also values exact terms.
            Examples:
            - "email about project deadline from last week"
            - "functional programming articles comparing Haskell and Scala"
            - "that blog post about AI safety and alignment"
            - "recipe for chocolate cake Sarah sent me"
            Pro tip: Include both concepts and specific keywords for best results.

        previews: Whether to include content snippets in results.
            - True: Returns preview text and image previews (useful for quick scanning)
            - False: Returns just metadata (faster, less data)
            Default is False.

        modalities: Types of content to search. Leave empty to search all.
            Available types:
            - 'email': Email messages
            - 'blog': Blog posts and articles
            - 'book': Book sections and ebooks
            - 'forum': Forum posts (e.g., LessWrong, Reddit)
            - 'observation': AI observations (use search_observations instead)
            - 'photo': Images with extracted text
            - 'comic': Comics and graphic content
            - 'webpage': General web pages
            Examples:
            - ["email"] - only emails
            - ["blog", "forum"] - articles and forum posts
            - [] - search everything

        limit: Maximum results to return (1-100). Default 10.
            Increase for comprehensive searches, decrease for quick lookups.

    Returns:
        List of search results ranked by relevance, each containing:
        - id: Unique identifier for the source item
        - score: Relevance score (0-1, higher is better)
        - chunks: Matching content segments with metadata
        - content: Full details including:
            - For emails: sender, recipient, subject, date
            - For blogs: author, title, url, publish date
            - For books: title, author, chapter info
            - Type-specific fields for each modality
        - filename: Path to file if content is stored on disk

    Examples:
        # Find specific email
        results = await search_knowledge_base(
            query="Sarah deadline project proposal next Friday",
            modalities=["email"],
            previews=True,
            limit=5
        )

        # Search for technical articles
        results = await search_knowledge_base(
            query="functional programming monads category theory",
            modalities=["blog", "book"],
            limit=20
        )

        # Find everything about a topic
        results = await search_knowledge_base(
            query="machine learning deployment kubernetes docker",
            previews=True
        )

        # Quick lookup of a remembered document
        results = await search_knowledge_base(
            query="tax forms 2023 accountant recommendations",
            modalities=["email"],
            limit=3
        )

    Best practices:
    - Include context in queries ("email from Sarah" vs just "Sarah")
    - Use modalities to filter when you know the content type
    - Enable previews when you need to verify content before using
    - Combine with search_observations for complete context
    - Higher scores (>0.7) indicate strong matches
    - If no results, try broader queries or different phrasing
    """
    logger.info(f"MCP search for: {query}")

    upload_data = extract.extract_text(query)
    results = await search(
        upload_data,
        previews=previews,
        modalities=modalities,
        limit=limit,
        min_text_score=0.3,
        min_multimodal_score=0.3,
    )

    # Convert SearchResult objects to dictionaries for MCP
    return [result.model_dump() for result in results]


@mcp.tool()
async def observe(
    content: str,
    subject: str,
    observation_type: str = "general",
    confidence: float = 0.8,
    evidence: dict | None = None,
    tags: list[str] | None = None,
    session_id: str | None = None,
    agent_model: str = "unknown",
) -> dict:
    """
    Record an observation about the user to build long-term understanding.

    Purpose:
        This tool is part of a memory system designed to help AI agents build a
        deep, persistent understanding of users over time. Use it to record any
        notable information about the user's preferences, beliefs, behaviors, or
        characteristics. These observations accumulate to create a comprehensive
        model of the user that improves future interactions.

    Quick Reference:
        # Most common patterns:
        observe(content="User prefers X over Y because...", subject="preferences", observation_type="preference")
        observe(content="User always/often does X when Y", subject="work_habits", observation_type="behavior")
        observe(content="User believes/thinks X about Y", subject="beliefs_on_topic", observation_type="belief")
        observe(content="User said X but previously said Y", subject="topic", observation_type="contradiction")

    When to use:
    - User expresses a preference or opinion
    - You notice a behavioral pattern
    - User reveals information about their work/life/interests
    - You spot a contradiction with previous statements
    - Any insight that would help understand the user better in future

    Important: Be an active observer. Don't wait to be asked - proactively record
    observations throughout conversations to build understanding.

    Args:
        content: The observation itself. Be specific and detailed. Write complete
            thoughts that will make sense when read months later without context.
            Bad: "Likes FP"
            Good: "User strongly prefers functional programming paradigms, especially
            pure functions and immutability, considering them more maintainable"

        subject: A consistent identifier for what this observation is about. Use
            snake_case and be consistent across observations to enable tracking.
            Examples:
            - "programming_style" (not "coding" or "development")
            - "work_habits" (not "productivity" or "work_patterns")
            - "ai_safety_beliefs" (not "AI" or "artificial_intelligence")

        observation_type: Categorize the observation:
            - "belief": An opinion or belief the user holds
            - "preference": Something they prefer or favor
            - "behavior": A pattern in how they act or work
            - "contradiction": An inconsistency with previous observations
            - "general": Doesn't fit other categories

        confidence: How certain you are (0.0-1.0):
            - 1.0: User explicitly stated this
            - 0.9: Strongly implied or demonstrated repeatedly
            - 0.8: Inferred with high confidence (default)
            - 0.7: Probable but with some uncertainty
            - 0.6 or below: Speculative, use sparingly

        evidence: Supporting context as a dict. Include relevant details:
            - "quote": Exact words from the user
            - "context": What prompted this observation
            - "timestamp": When this was observed
            - "related_to": Connection to other topics
            Example: {
                "quote": "I always refactor to pure functions",
                "context": "Discussing code review practices"
            }

        tags: Categorization labels. Use lowercase with hyphens. Common patterns:
            - Topics: "machine-learning", "web-development", "philosophy"
            - Projects: "project:website-redesign", "project:thesis"
            - Contexts: "context:work", "context:personal", "context:late-night"
            - Domains: "domain:finance", "domain:healthcare"

        session_id: UUID string to group observations from the same conversation.
            Generate one UUID per conversation and reuse it for all observations
            in that conversation. Format: "550e8400-e29b-41d4-a716-446655440000"

        agent_model: Which AI model made this observation (e.g., "claude-3-opus",
            "gpt-4", "claude-3.5-sonnet"). Helps track observation quality.

    Returns:
        Dict with created observation details:
        - id: Unique identifier for reference
        - created_at: Timestamp of creation
        - subject: The subject as stored
        - observation_type: The type as stored
        - confidence: The confidence score
        - tags: List of applied tags

    Examples:
        # After user mentions their coding philosophy
        await observe(
            content="User believes strongly in functional programming principles, "
                    "particularly avoiding mutable state which they call 'the root "
                    "of all evil'. They prioritize code purity over performance.",
            subject="programming_philosophy",
            observation_type="belief",
            confidence=0.95,
            evidence={
                "quote": "State is the root of all evil in programming",
                "context": "Discussing why they chose Haskell for their project"
            },
            tags=["programming", "functional-programming", "philosophy"],
            session_id="550e8400-e29b-41d4-a716-446655440000",
            agent_model="claude-3-opus"
        )

        # Noticing a work pattern
        await observe(
            content="User frequently works on complex problems late at night, "
                    "typically between 11pm and 3am, claiming better focus",
            subject="work_schedule",
            observation_type="behavior",
            confidence=0.85,
            evidence={
                "context": "Mentioned across multiple conversations over 2 weeks"
            },
            tags=["behavior", "work-habits", "productivity", "context:late-night"],
            agent_model="claude-3-opus"
        )

        # Recording a contradiction
        await observe(
            content="User now advocates for microservices architecture, but "
                    "previously argued strongly for monoliths in similar contexts",
            subject="architecture_preferences",
            observation_type="contradiction",
            confidence=0.9,
            evidence={
                "quote": "Microservices are definitely the way to go",
                "context": "Designing a new system similar to one from 3 months ago"
            },
            tags=["architecture", "contradiction", "software-design"],
            agent_model="gpt-4"
        )
    """
    # Create the observation
    observation = AgentObservation(
        content=content,
        subject=subject,
        observation_type=observation_type,
        confidence=confidence,
        evidence=evidence,
        tags=tags or [],
        session_id=uuid.UUID(session_id) if session_id else None,
        agent_model=agent_model,
        size=len(content),
        mime_type="text/plain",
        sha256=sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest(),
        inserted_at=datetime.now(timezone.utc),
    )
    try:
        with make_session() as session:
            process_content_item(observation, session)

            if not observation.id:
                raise ValueError("Observation not created")

            logger.info(
                f"Observation created: {observation.id}, {observation.inserted_at}"
            )
            return {
                "id": observation.id,
                "created_at": observation.inserted_at.isoformat(),
                "subject": observation.subject,
                "observation_type": observation.observation_type,
                "confidence": cast(float, observation.confidence),
                "tags": observation.tags,
            }

    except Exception as e:
        logger.error(f"Error creating observation: {e}")
        raise


@mcp.tool()
async def search_observations(
    query: str,
    subject: str = "",
    tags: list[str] | None = None,
    observation_types: list[str] | None = None,
    min_confidence: float = 0.5,
    limit: int = 10,
) -> list[dict]:
    """
    Search through observations to understand the user better.

    Purpose:
        This tool searches through all observations recorded about the user using
        the 'observe' tool. Use it to recall past insights, check for patterns,
        find contradictions, or understand the user's preferences before responding.
        The more you use this tool, the more personalized and insightful your
        responses can be.

    When to use:
    - Before answering questions where user preferences might matter
    - When the user references something from the past
    - To check if current behavior aligns with past patterns
    - To find related observations on a topic
    - To build context about the user's expertise or interests
    - Whenever personalization would improve your response

    How it works:
        Uses hybrid search combining semantic similarity with keyword matching.
        Searches across multiple embedding spaces (semantic meaning and temporal
        context) to find relevant observations from different angles. This approach
        ensures you find both conceptually related and specifically mentioned items.

    Args:
        query: Natural language description of what you're looking for. The search
            matches both meaning and specific terms in observation content.
            Examples:
            - "programming preferences and coding style"
            - "opinions about artificial intelligence and AI safety"
            - "work habits productivity patterns when does user work best"
            - "previous projects the user has worked on"
            Pro tip: Use natural language but include key terms you expect to find.

        subject: Filter by exact subject identifier. Must match subjects used when
            creating observations (e.g., "programming_style", "work_habits").
            Leave empty to search all subjects. Use this when you know the exact
            subject category you want.

        tags: Filter results to only observations with these tags. Observations must
            have at least one matching tag. Use the same format as when creating:
            - ["programming", "functional-programming"]
            - ["context:work", "project:thesis"]
            - ["domain:finance", "machine-learning"]

        observation_types: Filter by type of observation:
            - "belief": Opinions or beliefs the user holds
            - "preference": Things they prefer or favor
            - "behavior": Patterns in how they act or work
            - "contradiction": Noted inconsistencies
            - "general": Other observations
            Leave as None to search all types.

        min_confidence: Only return observations with confidence >= this value.
            - Use 0.8+ for high-confidence facts
            - Use 0.5-0.7 to include inferred observations
            - Default 0.5 includes most observations
            Range: 0.0 to 1.0

        limit: Maximum results to return (1-100). Default 10. Increase when you
            need comprehensive understanding of a topic.

    Returns:
        List of observations sorted by relevance, each containing:
        - subject: What the observation is about
        - content: The full observation text
        - observation_type: Type of observation
        - evidence: Supporting context/quotes if provided
        - confidence: How certain the observation is (0-1)
        - agent_model: Which AI model made the observation
        - tags: All tags on this observation
        - created_at: When it was observed (if available)

    Examples:
        # Before discussing code architecture
        results = await search_observations(
            query="software architecture preferences microservices monoliths",
            tags=["architecture"],
            min_confidence=0.7
        )

        # Understanding work style for scheduling
        results = await search_observations(
            query="when does user work best productivity schedule",
            observation_types=["behavior", "preference"],
            subject="work_schedule"
        )

        # Check for AI safety views before discussing AI
        results = await search_observations(
            query="artificial intelligence safety alignment concerns",
            observation_types=["belief"],
            min_confidence=0.8,
            limit=20
        )

        # Find contradictions on a topic
        results = await search_observations(
            query="testing methodology unit tests integration",
            observation_types=["contradiction"],
            tags=["testing", "software-development"]
        )

    Best practices:
    - Search before making assumptions about user preferences
    - Use broad queries first, then filter with tags/types if too many results
    - Check for contradictions when user says something unexpected
    - Higher confidence observations are more reliable
    - Recent observations may override older ones on same topic
    """
    source_ids = None
    if tags or observation_types:
        with make_session() as session:
            items_query = session.query(AgentObservation.id)

            if tags:
                # Use PostgreSQL array overlap operator with proper array casting
                items_query = items_query.filter(
                    AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
                )
            if observation_types:
                items_query = items_query.filter(
                    AgentObservation.observation_type.in_(observation_types)
                )
            source_ids = [item.id for item in items_query.all()]
            if not source_ids:
                return []

    semantic_text = observation.generate_semantic_text(
        subject=subject or "",
        observation_type="".join(observation_types or []),
        content=query,
        evidence=None,
    )
    temporal = observation.generate_temporal_text(
        subject=subject or "",
        content=query,
        confidence=0,
        created_at=datetime.now(timezone.utc),
    )
    results = await search(
        [
            extract.DataChunk(data=[query]),
            extract.DataChunk(data=[semantic_text]),
            extract.DataChunk(data=[temporal]),
        ],
        previews=True,
        modalities=["semantic", "temporal"],
        limit=limit,
        min_text_score=0.8,
        filters=SearchFilters(
            subject=subject,
            confidence=min_confidence,
            tags=tags,
            observation_types=observation_types,
            source_ids=source_ids,
        ),
    )

    return [
        cast(dict, cast(dict, result.model_dump()).get("content")) for result in results
    ]
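The observe() tool above derives a deduplication hash from the observation's content, subject and type, and groups observations by a per-conversation UUID. A standalone illustration of those two pieces, using only the standard library (no server or database involved):

import uuid
from hashlib import sha256

content = "User strongly prefers functional programming paradigms"
subject = "programming_philosophy"
observation_type = "belief"

# Same expression observe() uses for its sha256 field: identical
# content/subject/type triples hash identically, which is what lets
# duplicate observations be detected downstream.
digest = sha256(f"{content}{subject}{observation_type}".encode("utf-8")).digest()

# One UUID per conversation, reused for every observation in that
# conversation, as the observe() docstring recommends.
session_id = str(uuid.uuid4())

print(digest.hex(), session_id)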
@@ -18,6 +18,7 @@ from memory.common.db.models import (
     ArticleFeed,
     EmailAccount,
     ForumPost,
+    AgentObservation,
 )
 
 
@@ -53,6 +54,7 @@ class SourceItemAdmin(ModelView, model=SourceItem):
 
 class ChunkAdmin(ModelView, model=Chunk):
     column_list = ["id", "source_id", "embedding_model", "created_at"]
+    column_sortable_list = ["created_at"]
 
 
 class MailMessageAdmin(ModelView, model=MailMessage):
@@ -174,9 +176,23 @@ class EmailAccountAdmin(ModelView, model=EmailAccount):
     column_searchable_list = ["name", "email_address"]
 
 
+class AgentObservationAdmin(ModelView, model=AgentObservation):
+    column_list = [
+        "id",
+        "content",
+        "subject",
+        "observation_type",
+        "confidence",
+        "evidence",
+        "inserted_at",
+    ]
+    column_searchable_list = ["subject", "observation_type"]
+
+
 def setup_admin(admin: Admin):
     """Add all admin views to the admin instance."""
     admin.add_view(SourceItemAdmin)
+    admin.add_view(AgentObservationAdmin)
     admin.add_view(ChunkAdmin)
     admin.add_view(EmailAccountAdmin)
     admin.add_view(MailMessageAdmin)
@@ -2,6 +2,7 @@
 FastAPI application for the knowledge base.
 """
 
+import contextlib
 import pathlib
 import logging
 from typing import Annotated, Optional
@@ -15,15 +16,25 @@ from memory.common import extract
 from memory.common.db.connection import get_engine
 from memory.api.admin import setup_admin
 from memory.api.search import search, SearchResult
+from memory.api.MCP.tools import mcp
 
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="Knowledge Base API")
+
+@contextlib.asynccontextmanager
+async def lifespan(app: FastAPI):
+    async with contextlib.AsyncExitStack() as stack:
+        await stack.enter_async_context(mcp.session_manager.run())
+        yield
+
+
+app = FastAPI(title="Knowledge Base API", lifespan=lifespan)
 
 # SQLAdmin setup
 engine = get_engine()
 admin = Admin(app, engine)
 setup_admin(admin)
+app.mount("/", mcp.streamable_http_app())
 
 
 @app.get("/health")
@@ -104,4 +115,7 @@ def main():
 
 
 if __name__ == "__main__":
+    from memory.common.qdrant import setup_qdrant
+
+    setup_qdrant()
     main()
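A quick smoke test of the wiring above, as a sketch only: the host/port and the shape of the /health response are assumptions, since neither the run command nor the handler body appears in this hunk.

import httpx

# With the server running locally, the /health route registered above should
# respond; the MCP streamable-HTTP app is mounted at "/".
resp = httpx.get("http://localhost:8000/health")
print(resp.status_code, resp.text)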
@@ -2,21 +2,29 @@
 Search endpoints for the knowledge base API.
 """
 
+import asyncio
 import base64
+from hashlib import sha256
 import io
-from collections import defaultdict
-from typing import Callable, Optional
 import logging
+from collections import defaultdict
+from typing import Any, Callable, Optional, TypedDict, NotRequired
+
+import bm25s
+import Stemmer
+import qdrant_client
 from PIL import Image
 from pydantic import BaseModel
-import qdrant_client
 from qdrant_client.http import models as qdrant_models
 
-from memory.common import embedding, qdrant, extract, settings
+from memory.common import embedding, extract, qdrant, settings
-from memory.common.collections import TEXT_COLLECTIONS, ALL_COLLECTIONS
+from memory.common.collections import (
+    ALL_COLLECTIONS,
+    MULTIMODAL_COLLECTIONS,
+    TEXT_COLLECTIONS,
+)
 from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk, SourceItem
+from memory.common.db.models import Chunk
 
 logger = logging.getLogger(__name__)
@@ -28,6 +36,17 @@ class AnnotatedChunk(BaseModel):
     preview: Optional[str | None] = None
 
 
+class SourceData(BaseModel):
+    """Holds source item data to avoid SQLAlchemy session issues"""
+
+    id: int
+    size: int | None
+    mime_type: str | None
+    filename: str | None
+    content: str | dict | None
+    content_length: int
+
+
 class SearchResponse(BaseModel):
     collection: str
     results: list[dict]
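SourceData's docstring says it exists to avoid SQLAlchemy session issues. A self-contained sketch of the failure mode it guards against, using an illustrative model rather than this repo's: once a session commits and then closes, expired attributes on the detached instance raise DetachedInstanceError, so plain values are copied out while the session is still open.

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # illustrative model, not the repo's SourceItem
    __tablename__ = "item"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    item = Item(name="example")
    session.add(item)
    session.flush()                             # assigns item.id
    safe = {"id": item.id, "name": item.name}   # plain values copied while attached
    session.commit()                            # expires the ORM attributes

# `item` is now detached with expired attributes: accessing item.name here raises
# sqlalchemy.orm.exc.DetachedInstanceError, while the plain copy keeps working.
print(safe)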
@@ -38,16 +57,46 @@ class SearchResult(BaseModel):
     size: int
     mime_type: str
     chunks: list[AnnotatedChunk]
-    content: Optional[str] = None
+    content: Optional[str | dict] = None
     filename: Optional[str] = None
 
 
+class SearchFilters(TypedDict):
+    subject: NotRequired[str | None]
+    confidence: NotRequired[float]
+    tags: NotRequired[list[str] | None]
+    observation_types: NotRequired[list[str] | None]
+    source_ids: NotRequired[list[int] | None]
+
+
+async def with_timeout(
+    call, timeout: int = 2
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    """
+    Run a function with a timeout.
+
+    Args:
+        call: The function to run
+        timeout: The timeout in seconds
+    """
+    try:
+        return await asyncio.wait_for(call, timeout=timeout)
+    except TimeoutError:
+        logger.warning(f"Search timed out after {timeout}s")
+        return []
+    except Exception as e:
+        logger.error(f"Search failed: {e}")
+        return []
+
+
 def annotated_chunk(
     chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceItem, AnnotatedChunk]:
+) -> tuple[SourceData, AnnotatedChunk]:
     def serialize_item(item: bytes | str | Image.Image) -> str | None:
         if not previews and not isinstance(item, str):
             return None
+        if not previews and isinstance(item, str):
+            return item[:100]
 
         if isinstance(item, Image.Image):
             buffer = io.BytesIO()
@@ -68,7 +117,19 @@ def annotated_chunk(
         for k, v in metadata.items()
         if k not in ["content", "filename", "size", "content_type", "tags"]
     }
-    return chunk.source, AnnotatedChunk(
+
+    # Prefetch all needed source data while in session
+    source = chunk.source
+    source_data = SourceData(
+        id=source.id,
+        size=source.size,
+        mime_type=source.mime_type,
+        filename=source.filename,
+        content=source.display_contents,
+        content_length=len(source.content) if source.content else 0,
+    )
+
+    return source_data, AnnotatedChunk(
         id=str(chunk.id),
         score=search_result.score,
         metadata=metadata,
@@ -76,24 +137,28 @@ def annotated_chunk(
     )
 
 
-def group_chunks(chunks: list[tuple[SourceItem, AnnotatedChunk]]) -> list[SearchResult]:
+def group_chunks(chunks: list[tuple[SourceData, AnnotatedChunk]]) -> list[SearchResult]:
     items = defaultdict(list)
+    source_lookup = {}
+
     for source, chunk in chunks:
-        items[source].append(chunk)
+        items[source.id].append(chunk)
+        source_lookup[source.id] = source
+
     return [
         SearchResult(
             id=source.id,
-            size=source.size or len(source.content),
+            size=source.size or source.content_length,
             mime_type=source.mime_type or "text/plain",
             filename=source.filename
             and source.filename.replace(
                 str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
             ),
-            content=source.display_contents,
+            content=source.content,
             chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
         )
-        for source, chunks in items.items()
+        for source_id, chunks in items.items()
+        for source in [source_lookup[source_id]]
     ]
@@ -104,25 +169,31 @@ def query_chunks(
     embedder: Callable,
     min_score: float = 0.0,
     limit: int = 10,
+    filters: dict[str, Any] | None = None,
 ) -> dict[str, list[qdrant_models.ScoredPoint]]:
-    if not upload_data:
+    if not upload_data or not allowed_modalities:
         return {}
 
-    chunks = [chunk for data_chunk in upload_data for chunk in data_chunk.data]
+    chunks = [chunk for chunk in upload_data if chunk.data]
     if not chunks:
         logger.error(f"No chunks to embed for {allowed_modalities}")
         return {}
 
-    vector = embedder(chunks, input_type="query")[0]
+    logger.error(f"Embedding {len(chunks)} chunks for {allowed_modalities}")
+    for c in chunks:
+        logger.error(f"Chunk: {c.data}")
+    vectors = embedder([c.data for c in chunks], input_type="query")
+
     return {
         collection: [
             r
+            for vector in vectors
            for r in qdrant.search_vectors(
                 client=client,
                 collection_name=collection,
                 query_vector=vector,
                 limit=limit,
+                filter_params=filters,
             )
             if r.score >= min_score
         ]
@@ -130,6 +201,122 @@ def query_chunks(
     }
 
 
+async def search_bm25(
+    query: str,
+    modalities: list[str],
+    limit: int = 10,
+    filters: SearchFilters = SearchFilters(),
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    with make_session() as db:
+        items_query = db.query(Chunk.id, Chunk.content).filter(
+            Chunk.collection_name.in_(modalities)
+        )
+        if source_ids := filters.get("source_ids"):
+            items_query = items_query.filter(Chunk.source_id.in_(source_ids))
+        items = items_query.all()
+        item_ids = {
+            sha256(item.content.lower().strip().encode("utf-8")).hexdigest(): item.id
+            for item in items
+        }
+        corpus = [item.content.lower().strip() for item in items]
+
+    stemmer = Stemmer.Stemmer("english")
+    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
+    retriever = bm25s.BM25()
+    retriever.index(corpus_tokens)
+
+    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
+    results, scores = retriever.retrieve(
+        query_tokens, k=min(limit, len(corpus)), corpus=corpus
+    )
+
+    item_scores = {
+        item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
+        for doc, score in zip(results[0], scores[0])
+    }
+
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
+        results = []
+        for chunk in chunks:
+            # Prefetch all needed source data while in session
+            source = chunk.source
+            source_data = SourceData(
+                id=source.id,
+                size=source.size,
+                mime_type=source.mime_type,
+                filename=source.filename,
+                content=source.display_contents,
+                content_length=len(source.content) if source.content else 0,
+            )
+
+            annotated = AnnotatedChunk(
+                id=str(chunk.id),
+                score=item_scores[chunk.id],
+                metadata=source.as_payload(),
+                preview=None,
+            )
+            results.append((source_data, annotated))
+
+        return results
+
+
+async def search_embeddings(
+    data: list[extract.DataChunk],
+    previews: Optional[bool] = False,
+    modalities: set[str] = set(),
+    limit: int = 10,
+    min_score: float = 0.3,
+    filters: SearchFilters = SearchFilters(),
+    multimodal: bool = False,
+) -> list[tuple[SourceData, AnnotatedChunk]]:
+    """
+    Search across knowledge base using text query and optional files.
+
+    Parameters:
+    - data: List of data to search in (e.g., text, images, files)
+    - previews: Whether to include previews in the search results
+    - modalities: List of modalities to search in (e.g., "text", "photo", "doc")
+    - limit: Maximum number of results
+    - min_score: Minimum score to include in the search results
+    - filters: Filters to apply to the search results
+    - multimodal: Whether to search in multimodal collections
+    """
+    query_filters = {
+        "must": [
+            {"key": "confidence", "range": {"gte": filters.get("confidence", 0.5)}},
+        ],
+    }
+    if tags := filters.get("tags"):
+        query_filters["must"] += [{"key": "tags", "match": {"any": tags}}]
+    if observation_types := filters.get("observation_types"):
+        query_filters["must"] += [
+            {"key": "observation_type", "match": {"any": observation_types}}
+        ]
+
+    client = qdrant.get_qdrant_client()
+    results = query_chunks(
+        client,
+        data,
+        modalities,
+        embedding.embed_text if not multimodal else embedding.embed_mixed,
+        min_score=min_score,
+        limit=limit,
+        filters=query_filters,
+    )
+    search_results = {k: results.get(k, []) for k in modalities}
+
+    found_chunks = {
+        str(r.id): r for results in search_results.values() for r in results
+    }
+    with make_session() as db:
+        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
+        return [
+            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
+            for chunk in chunks
+        ]
+
+
 async def search(
     data: list[extract.DataChunk],
     previews: Optional[bool] = False,
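For orientation, a minimal sketch of calling the new BM25 path from an async context. The module path, the modality names and the query string are assumptions for illustration; only the search_bm25 signature comes from the hunk above.

import asyncio

from memory.api.search import search_bm25  # assumed module path

async def demo() -> None:
    # "blog" and "forum" are example modality names; limit caps the hits returned.
    hits = await search_bm25("bayesian reasoning", ["blog", "forum"], limit=5)
    for source_data, annotated_chunk in hits:
        print(source_data, annotated_chunk)

asyncio.run(demo())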
@@ -137,6 +324,7 @@ async def search(
     limit: int = 10,
     min_text_score: float = 0.3,
     min_multimodal_score: float = 0.3,
+    filters: SearchFilters = {},
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -150,40 +338,45 @@ async def search(
     Returns:
     - List of search results sorted by score
     """
-    client = qdrant.get_qdrant_client()
     allowed_modalities = set(modalities or ALL_COLLECTIONS.keys())
-    text_results = query_chunks(
-        client,
-        data,
-        allowed_modalities & TEXT_COLLECTIONS,
-        embedding.embed_text,
-        min_score=min_text_score,
-        limit=limit,
-    )
-    multimodal_results = query_chunks(
-        client,
-        data,
-        allowed_modalities,
-        embedding.embed_mixed,
-        min_score=min_multimodal_score,
-        limit=limit,
-    )
-    search_results = {
-        k: text_results.get(k, []) + multimodal_results.get(k, [])
-        for k in allowed_modalities
-    }
-
-    found_chunks = {
-        str(r.id): r for results in search_results.values() for r in results
-    }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        logger.error(f"Found chunks: {chunks}")
-
-        results = group_chunks(
-            [
-                annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-                for chunk in chunks
-            ]
-        )
+    text_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & TEXT_COLLECTIONS,
+            limit,
+            min_text_score,
+            filters,
+            multimodal=False,
+        )
+    )
+    multimodal_embeddings_results = with_timeout(
+        search_embeddings(
+            data,
+            previews,
+            allowed_modalities & MULTIMODAL_COLLECTIONS,
+            limit,
+            min_multimodal_score,
+            filters,
+            multimodal=True,
+        )
+    )
+    bm25_results = with_timeout(
+        search_bm25(
+            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
+            modalities,
+            limit=limit,
+            filters=filters,
+        )
+    )
+
+    results = await asyncio.gather(
+        text_embeddings_results,
+        multimodal_embeddings_results,
+        bm25_results,
+        return_exceptions=False,
+    )
+
+    results = group_chunks([c for r in results for c in r])
     return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
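The rewritten search relies on a with_timeout helper that does not appear in these hunks. A minimal sketch of what such a wrapper could look like, assuming it simply bounds each sub-search with asyncio.wait_for and turns a timeout into an empty result so one slow backend cannot stall the gather():

import asyncio
from typing import Any, Coroutine

def with_timeout(
    coro: Coroutine[Any, Any, list], seconds: float = 2.0  # timeout value is an assumption
) -> Coroutine[Any, Any, list]:
    async def bounded() -> list:
        try:
            return await asyncio.wait_for(coro, timeout=seconds)
        except asyncio.TimeoutError:
            # Degrade gracefully: a timed-out backend contributes no chunks.
            return []

    return bounded()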
@@ -1,6 +1,7 @@
 import logging
 from typing import Literal, NotRequired, TypedDict
 
+from PIL import Image
 
 from memory.common import settings
 
@@ -14,73 +15,92 @@ Vector = list[float]
 class Collection(TypedDict):
     dimension: int
     distance: DistanceType
-    model: str
     on_disk: NotRequired[bool]
     shards: NotRequired[int]
+    text: bool
+    multimodal: bool
 
 
 ALL_COLLECTIONS: dict[str, Collection] = {
     "mail": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "chat": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "git": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "book": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
     "blog": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "forum": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": True,
     },
     "text": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.TEXT_EMBEDDING_MODEL,
+        "text": True,
+        "multimodal": False,
     },
-    # Multimodal
     "photo": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     "comic": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
     },
     "doc": {
         "dimension": 1024,
         "distance": "Cosine",
-        "model": settings.MIXED_EMBEDDING_MODEL,
+        "text": False,
+        "multimodal": True,
+    },
+    # Observations
+    "semantic": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "text": True,
+        "multimodal": False,
+    },
+    "temporal": {
+        "dimension": 1024,
+        "distance": "Cosine",
+        "text": True,
+        "multimodal": False,
     },
 }
 TEXT_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.TEXT_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("text")
 }
 MULTIMODAL_COLLECTIONS = {
-    coll
-    for coll, params in ALL_COLLECTIONS.items()
-    if params["model"] == settings.MIXED_EMBEDDING_MODEL
+    coll for coll, params in ALL_COLLECTIONS.items() if params.get("multimodal")
 }
 
 TYPES = {
@@ -108,5 +128,12 @@ def get_modality(mime_type: str) -> str:
     return "unknown"
 
 
-def collection_model(collection: str) -> str | None:
-    return ALL_COLLECTIONS.get(collection, {}).get("model")
+def collection_model(
+    collection: str, text: str, images: list[Image.Image]
+) -> str | None:
+    config = ALL_COLLECTIONS.get(collection, {})
+    if images and config.get("multimodal"):
+        return settings.MIXED_EMBEDDING_MODEL
+    if text and config.get("text"):
+        return settings.TEXT_EMBEDDING_MODEL
+    return "unknown"
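A quick sketch of how the reworked collection_model resolves an embedding model from the text/multimodal flags above; the PIL image object is purely illustrative.

from PIL import Image

from memory.common.collections import collection_model

# Text in a text-capable collection -> settings.TEXT_EMBEDDING_MODEL.
collection_model("mail", "some plain text", [])

# A chunk carrying an image in a multimodal collection -> settings.MIXED_EMBEDDING_MODEL.
img = Image.new("RGB", (32, 32))
collection_model("photo", "", [img])

# Neither flag matches (e.g. image data for a text-only collection) -> "unknown".
collection_model("git", "", [img])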
@@ -2,7 +2,7 @@
 Database utilities package.
 """
 
-from memory.common.db.models import Base
+from memory.common.db.models.base import Base
 from memory.common.db.connection import (
     get_engine,
     get_session_factory,
New file: src/memory/common/db/models/__init__.py (61 lines)

from memory.common.db.models.base import Base
from memory.common.db.models.source_item import (
    Chunk,
    SourceItem,
    clean_filename,
)
from memory.common.db.models.source_items import (
    MailMessage,
    EmailAttachment,
    AgentObservation,
    ChatMessage,
    BlogPost,
    Comic,
    BookSection,
    ForumPost,
    GithubItem,
    GitCommit,
    Photo,
    MiscDoc,
)
from memory.common.db.models.observations import (
    ObservationContradiction,
    ReactionPattern,
    ObservationPattern,
    BeliefCluster,
    ConversationMetrics,
)
from memory.common.db.models.sources import (
    Book,
    ArticleFeed,
    EmailAccount,
)


__all__ = [
    "Base",
    "Chunk",
    "clean_filename",
    "SourceItem",
    "MailMessage",
    "EmailAttachment",
    "AgentObservation",
    "ChatMessage",
    "BlogPost",
    "Comic",
    "BookSection",
    "ForumPost",
    "GithubItem",
    "GitCommit",
    "Photo",
    "MiscDoc",
    # Observations
    "ObservationContradiction",
    "ReactionPattern",
    "ObservationPattern",
    "BeliefCluster",
    "ConversationMetrics",
    # Sources
    "Book",
    "ArticleFeed",
    "EmailAccount",
]

New file: src/memory/common/db/models/base.py (4 lines)

from sqlalchemy.ext.declarative import declarative_base


Base = declarative_base()
New file: src/memory/common/db/models/observations.py (171 lines)

"""
Agent observation models for the epistemic sparring partner system.
"""

from sqlalchemy import (
    ARRAY,
    BigInteger,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    Text,
    UUID,
    func,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship

from memory.common.db.models.base import Base


class ObservationContradiction(Base):
    """
    Tracks contradictions between observations.
    Can be detected automatically or reported by agents.
    """

    __tablename__ = "observation_contradiction"

    id = Column(BigInteger, primary_key=True)
    observation_1_id = Column(
        BigInteger,
        ForeignKey("agent_observation.id", ondelete="CASCADE"),
        nullable=False,
    )
    observation_2_id = Column(
        BigInteger,
        ForeignKey("agent_observation.id", ondelete="CASCADE"),
        nullable=False,
    )
    contradiction_type = Column(Text, nullable=False)  # direct, implied, temporal
    detected_at = Column(DateTime(timezone=True), server_default=func.now())
    detection_method = Column(Text, nullable=False)  # manual, automatic, agent-reported
    resolution = Column(Text)  # How it was resolved, if at all
    observation_metadata = Column(JSONB)

    # Relationships - use string references to avoid circular imports
    observation_1 = relationship(
        "AgentObservation",
        foreign_keys=[observation_1_id],
        back_populates="contradictions_as_first",
    )
    observation_2 = relationship(
        "AgentObservation",
        foreign_keys=[observation_2_id],
        back_populates="contradictions_as_second",
    )

    __table_args__ = (
        Index("obs_contra_obs1_idx", "observation_1_id"),
        Index("obs_contra_obs2_idx", "observation_2_id"),
        Index("obs_contra_type_idx", "contradiction_type"),
        Index("obs_contra_method_idx", "detection_method"),
    )


class ReactionPattern(Base):
    """
    Tracks patterns in how the user reacts to certain types of observations or challenges.
    """

    __tablename__ = "reaction_pattern"

    id = Column(BigInteger, primary_key=True)
    trigger_type = Column(
        Text, nullable=False
    )  # What kind of observation triggers this
    reaction_type = Column(Text, nullable=False)  # How user typically responds
    frequency = Column(Numeric(3, 2), nullable=False)  # How often this pattern appears
    first_observed = Column(DateTime(timezone=True), server_default=func.now())
    last_observed = Column(DateTime(timezone=True), server_default=func.now())
    example_observations = Column(
        ARRAY(BigInteger)
    )  # IDs of observations showing this pattern
    reaction_metadata = Column(JSONB)

    __table_args__ = (
        Index("reaction_trigger_idx", "trigger_type"),
        Index("reaction_type_idx", "reaction_type"),
        Index("reaction_frequency_idx", "frequency"),
    )


class ObservationPattern(Base):
    """
    Higher-level patterns detected across multiple observations.
    """

    __tablename__ = "observation_pattern"

    id = Column(BigInteger, primary_key=True)
    pattern_type = Column(Text, nullable=False)  # behavioral, cognitive, emotional
    description = Column(Text, nullable=False)
    supporting_observations = Column(
        ARRAY(BigInteger), nullable=False
    )  # Observation IDs
    exceptions = Column(ARRAY(BigInteger))  # Observations that don't fit
    confidence = Column(Numeric(3, 2), nullable=False, default=0.7)
    validity_start = Column(DateTime(timezone=True))
    validity_end = Column(DateTime(timezone=True))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now())
    pattern_metadata = Column(JSONB)

    __table_args__ = (
        Index("obs_pattern_type_idx", "pattern_type"),
        Index("obs_pattern_confidence_idx", "confidence"),
        Index("obs_pattern_validity_idx", "validity_start", "validity_end"),
    )


class BeliefCluster(Base):
    """
    Groups of related beliefs that support or depend on each other.
    """

    __tablename__ = "belief_cluster"

    id = Column(BigInteger, primary_key=True)
    cluster_name = Column(Text, nullable=False)
    core_beliefs = Column(ARRAY(Text), nullable=False)
    peripheral_beliefs = Column(ARRAY(Text))
    internal_consistency = Column(Numeric(3, 2))  # How well beliefs align
    supporting_observations = Column(ARRAY(BigInteger))  # Observation IDs
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    last_updated = Column(DateTime(timezone=True), server_default=func.now())
    cluster_metadata = Column(JSONB)

    __table_args__ = (
        Index("belief_cluster_name_idx", "cluster_name"),
        Index("belief_cluster_consistency_idx", "internal_consistency"),
    )


class ConversationMetrics(Base):
    """
    Tracks the effectiveness and depth of conversations.
    """

    __tablename__ = "conversation_metrics"

    id = Column(BigInteger, primary_key=True)
    session_id = Column(UUID(as_uuid=True), nullable=False)
    depth_score = Column(Numeric(3, 2))  # How deep the conversation went
    breakthrough_count = Column(Integer, default=0)
    challenge_acceptance = Column(Numeric(3, 2))  # How well challenges were received
    new_insights = Column(Integer, default=0)
    user_engagement = Column(Numeric(3, 2))  # Inferred engagement level
    duration_minutes = Column(Integer)
    observation_count = Column(Integer, default=0)
    contradiction_count = Column(Integer, default=0)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    metrics_metadata = Column(JSONB)

    __table_args__ = (
        Index("conv_metrics_session_idx", "session_id", unique=True),
        Index("conv_metrics_depth_idx", "depth_score"),
        Index("conv_metrics_breakthrough_idx", "breakthrough_count"),
    )
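A hedged sketch of writing one of these tables with a plain SQLAlchemy session; the DSN, the field values and the use of a bare Session are illustrative assumptions, not part of the diff.

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from memory.common.db.models.observations import ReactionPattern

engine = create_engine("postgresql://user:pass@localhost/memory")  # assumed DSN
with Session(engine) as session:
    # Illustrative only: record that challenges to a belief tend to be deflected.
    session.add(
        ReactionPattern(
            trigger_type="challenge_to_political_belief",
            reaction_type="deflection",
            frequency=0.65,
            example_observations=[12, 37],
        )
    )
    session.commit()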
New file: src/memory/common/db/models/source_item.py (294 lines)

"""
Database models for the knowledge base system.
"""

import pathlib
import re
from typing import Any, Sequence, cast
import uuid

from PIL import Image
from sqlalchemy import (
    ARRAY,
    UUID,
    BigInteger,
    CheckConstraint,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    String,
    Text,
    event,
    func,
)
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import Session, relationship

from memory.common import settings
import memory.common.extract as extract
import memory.common.collections as collections
import memory.common.chunker as chunker
import memory.common.summarizer as summarizer
from memory.common.db.models.base import Base


@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
    """
    Event listener that efficiently checks for duplicate sha256 values before flush
    and removes items with duplicate sha256 from the session.

    Uses a single query to identify all duplicates rather than querying for each item.
    """
    # Find all SourceItem objects being added
    new_items = [obj for obj in session.new if isinstance(obj, SourceItem)]
    if not new_items:
        return

    items = {}
    for item in new_items:
        try:
            if (sha256 := item.sha256) is None:
                continue

            if sha256 in items:
                session.expunge(item)
                continue

            items[sha256] = item
        except (AttributeError, TypeError):
            continue

    if not new_items:
        return

    # Query database for existing items with these sha256 values in a single query
    existing_sha256s = set(
        row[0]
        for row in session.query(SourceItem.sha256).filter(
            SourceItem.sha256.in_(items.keys())
        )
    )

    # Remove objects with duplicate sha256 values from the session
    for sha256 in existing_sha256s:
        if sha256 in items:
            session.expunge(items[sha256])


def clean_filename(filename: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")


def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
    for i, image in enumerate(images):
        if not image.filename:  # type: ignore
            filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}"  # type: ignore
            image.save(filename)
            image.filename = str(filename)  # type: ignore

    return [image.filename for image in images]  # type: ignore


def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
    return [chunk] + [
        i
        for i in images
        if getattr(i, "filename", None) and i.filename in chunk  # type: ignore
    ]


def merge_metadata(*metadata: dict[str, Any]) -> dict[str, Any]:
    final = {}
    for m in metadata:
        data = m.copy()
        if tags := set(data.pop("tags", [])):
            final["tags"] = tags | final.get("tags", set())
        final |= data
    return final


def chunk_mixed(content: str, image_paths: Sequence[str]) -> list[extract.DataChunk]:
    if not content.strip():
        return []

    images = [Image.open(settings.FILE_STORAGE_DIR / image) for image in image_paths]

    summary, tags = summarizer.summarize(content)
    full_text: extract.DataChunk = extract.DataChunk(
        data=[content.strip(), *images], metadata={"tags": tags}
    )

    chunks: list[extract.DataChunk] = [full_text]
    tokens = chunker.approx_token_count(content)
    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks += [
            extract.DataChunk(data=add_pics(c, images), metadata={"tags": tags})
            for c in chunker.chunk_text(content)
        ]
        chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))

    return [c for c in chunks if c.data]


class Chunk(Base):
    """Stores content chunks with their vector embeddings."""

    __tablename__ = "chunk"

    # The ID is also used as the vector ID in the vector database
    id = Column(
        UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
    )
    source_id = Column(
        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
    )
    file_paths = Column(
        ARRAY(Text), nullable=True
    )  # Path to content if stored as a file
    content = Column(Text)  # Direct content storage
    embedding_model = Column(Text)
    collection_name = Column(Text)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    checked_at = Column(DateTime(timezone=True), server_default=func.now())
    vector: list[float] = []
    item_metadata: dict[str, Any] = {}
    images: list[Image.Image] = []

    # One of file_path or content must be populated
    __table_args__ = (
        CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
        Index("chunk_source_idx", "source_id"),
    )

    @property
    def chunks(self) -> list[extract.MulitmodalChunk]:
        chunks: list[extract.MulitmodalChunk] = []
        if cast(str | None, self.content):
            chunks = [cast(str, self.content)]
        if self.images:
            chunks += self.images
        elif cast(Sequence[str] | None, self.file_paths):
            chunks += [
                Image.open(pathlib.Path(cast(str, cp))) for cp in self.file_paths
            ]
        return chunks

    @property
    def data(self) -> list[bytes | str | Image.Image]:
        if self.file_paths is None:
            return [cast(str, self.content)]

        paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
        files = [path for path in paths if path.exists()]

        items = []
        for file_path in files:
            if file_path.suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
                if file_path.exists():
                    items.append(Image.open(file_path))
            elif file_path.suffix == ".bin":
                items.append(file_path.read_bytes())
            else:
                items.append(file_path.read_text())
        return items


class SourceItem(Base):
    """Base class for all content in the system using SQLAlchemy's joined table inheritance."""

    __tablename__ = "source_item"
    __allow_unmapped__ = True

    id = Column(BigInteger, primary_key=True)
    modality = Column(Text, nullable=False)
    sha256 = Column(BYTEA, nullable=False, unique=True)
    inserted_at = Column(
        DateTime(timezone=True), server_default=func.now(), default=func.now()
    )
    tags = Column(ARRAY(Text), nullable=False, server_default="{}")
    size = Column(Integer)
    mime_type = Column(Text)

    # Content is stored in the database if it's small enough and text
    content = Column(Text)
    # Otherwise the content is stored on disk
    filename = Column(Text, nullable=True)

    # Chunks relationship
    embed_status = Column(Text, nullable=False, server_default="RAW")
    chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")

    # Discriminator column for SQLAlchemy inheritance
    type = Column(String(50))

    __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}

    # Add table-level constraint and indexes
    __table_args__ = (
        CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
        Index("source_modality_idx", "modality"),
        Index("source_status_idx", "embed_status"),
        Index("source_tags_idx", "tags", postgresql_using="gin"),
        Index("source_filename_idx", "filename"),
    )

    @property
    def vector_ids(self):
        """Get vector IDs from associated chunks."""
        return [chunk.id for chunk in self.chunks]

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        chunks: list[extract.DataChunk] = []
        content = cast(str | None, self.content)
        if content:
            chunks = [extract.DataChunk(data=[c]) for c in chunker.chunk_text(content)]

        if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
            summary, tags = summarizer.summarize(content)
            chunks.append(extract.DataChunk(data=[summary], metadata={"tags": tags}))

        mime_type = cast(str | None, self.mime_type)
        if mime_type and mime_type.startswith("image/"):
            chunks.append(extract.DataChunk(data=[Image.open(self.filename)]))
        return chunks

    def _make_chunk(
        self, data: extract.DataChunk, metadata: dict[str, Any] = {}
    ) -> Chunk:
        chunk_id = str(uuid.uuid4())
        text = "\n\n".join(c for c in data.data if isinstance(c, str) and c.strip())
        images = [c for c in data.data if isinstance(c, Image.Image)]
        image_names = image_filenames(chunk_id, images)

        modality = data.modality or cast(str, self.modality)
        chunk = Chunk(
            id=chunk_id,
            source=self,
            content=text or None,
            images=images,
            file_paths=image_names,
            collection_name=modality,
            embedding_model=collections.collection_model(modality, text, images),
            item_metadata=merge_metadata(self.as_payload(), data.metadata, metadata),
        )
        return chunk

    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
        return [self._make_chunk(data) for data in self._chunk_contents()]

    def as_payload(self) -> dict:
        return {
            "source_id": self.id,
            "tags": self.tags,
            "size": self.size,
        }

    @property
    def display_contents(self) -> str | dict | None:
        return {
            "tags": self.tags,
            "size": self.size,
        }
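merge_metadata above is easy to misread; a small self-contained illustration of its tag-union behaviour (tags from all inputs are collected into a set, other keys are merged left to right):

from memory.common.db.models.source_item import merge_metadata

merged = merge_metadata(
    {"tags": ["ai", "memory"], "size": 123},
    {"tags": ["ai", "epistemics"], "mime_type": "text/plain"},
)
# merged == {"tags": {"ai", "memory", "epistemics"}, "size": 123, "mime_type": "text/plain"}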
@@ -3,18 +3,14 @@ Database models for the knowledge base system.
 """
 
 import pathlib
-import re
 import textwrap
 from datetime import datetime
 from typing import Any, Sequence, cast
-import uuid
 
 from PIL import Image
 from sqlalchemy import (
     ARRAY,
-    UUID,
     BigInteger,
-    Boolean,
     CheckConstraint,
     Column,
     DateTime,
@@ -22,273 +18,24 @@ from sqlalchemy import (
     Index,
     Integer,
     Numeric,
-    String,
     Text,
-    event,
     func,
 )
-from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import Session, relationship
+from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR, UUID
+from sqlalchemy.orm import relationship
 
 from memory.common import settings
 import memory.common.extract as extract
-import memory.common.collections as collections
-import memory.common.chunker as chunker
 import memory.common.summarizer as summarizer
+import memory.common.formatters.observation as observation
 
-Base = declarative_base()
-[removed here: the handle_duplicate_sha256 before_flush listener, the clean_filename,
- image_filenames, add_pics, merge_metadata and chunk_mixed helpers, and the Chunk and
- SourceItem classes. They move to src/memory/common/db/models/source_item.py (shown
- above), with a few changes in the moved copies: Chunk gains a collection_name column,
- SourceItem.inserted_at gains default=func.now(), merge_metadata copies its inputs
- instead of mutating them, _make_chunk passes the chunk's modality, text and images to
- collections.collection_model and records collection_name, and display_contents now
- returns a tags/size dict instead of content or filename.]
+from memory.common.db.models.source_item import (
+    SourceItem,
+    Chunk,
+    clean_filename,
+    merge_metadata,
+    chunk_mixed,
+)
 
 
 class MailMessage(SourceItem):
@@ -528,51 +275,6 @@ class Comic(SourceItem):
         return [extract.DataChunk(data=[image, description])]
 
 
-[removed here: the Book model (book-level metadata table with its isbn/title/author
- columns, indexes and as_payload helper). It moves unchanged to
- src/memory/common/db/models/sources.py, shown below.]
 
 
 class BookSection(SourceItem):
     """Individual sections/chapters of books"""
@@ -797,58 +499,163 @@ class GithubItem(SourceItem):
     )
 
 
-[removed here: the ArticleFeed and EmailAccount tables. Both move unchanged to
- src/memory/common/db/models/sources.py, shown below.]
+class AgentObservation(SourceItem):
+    """
+    Records observations made by AI agents about the user.
+    This is the primary data model for the epistemic sparring partner.
+    """
+
+    __tablename__ = "agent_observation"
+
+    id = Column(
+        BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
+    )
+    session_id = Column(
+        UUID(as_uuid=True)
+    )  # Groups observations from same conversation
+    observation_type = Column(
+        Text, nullable=False
+    )  # belief, preference, pattern, contradiction, behavior
+    subject = Column(Text, nullable=False)  # What/who the observation is about
+    confidence = Column(Numeric(3, 2), nullable=False, default=0.8)  # 0.0-1.0
+    evidence = Column(JSONB)  # Supporting context, quotes, etc.
+    agent_model = Column(Text, nullable=False)  # Which AI model made this observation
+
+    # Relationships
+    contradictions_as_first = relationship(
+        "ObservationContradiction",
+        foreign_keys="ObservationContradiction.observation_1_id",
+        back_populates="observation_1",
+        cascade="all, delete-orphan",
+    )
+    contradictions_as_second = relationship(
+        "ObservationContradiction",
+        foreign_keys="ObservationContradiction.observation_2_id",
+        back_populates="observation_2",
+        cascade="all, delete-orphan",
+    )
+
+    __mapper_args__ = {
+        "polymorphic_identity": "agent_observation",
+    }
+
+    __table_args__ = (
+        Index("agent_obs_session_idx", "session_id"),
+        Index("agent_obs_type_idx", "observation_type"),
+        Index("agent_obs_subject_idx", "subject"),
+        Index("agent_obs_confidence_idx", "confidence"),
+        Index("agent_obs_model_idx", "agent_model"),
+    )
+
+    def __init__(self, **kwargs):
+        if not kwargs.get("modality"):
+            kwargs["modality"] = "observation"
+        super().__init__(**kwargs)
+
+    def as_payload(self) -> dict:
+        payload = {
+            **super().as_payload(),
+            "observation_type": self.observation_type,
+            "subject": self.subject,
+            "confidence": float(cast(Any, self.confidence)),
+            "evidence": self.evidence,
+            "agent_model": self.agent_model,
+        }
+        if self.session_id is not None:
+            payload["session_id"] = str(self.session_id)
+        return payload
+
+    @property
+    def all_contradictions(self):
+        """Get all contradictions involving this observation."""
+        return self.contradictions_as_first + self.contradictions_as_second
+
+    @property
+    def display_contents(self) -> dict:
+        return {
+            "subject": self.subject,
+            "content": self.content,
+            "observation_type": self.observation_type,
+            "evidence": self.evidence,
+            "confidence": self.confidence,
+            "agent_model": self.agent_model,
+            "tags": self.tags,
+        }
+
+    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+        """
+        Generate multiple chunks for different embedding dimensions.
+        Each chunk goes to a different Qdrant collection for specialized search.
+        """
+        # 1. Semantic chunk - standard content representation
+        semantic_text = observation.generate_semantic_text(
+            cast(str, self.subject),
+            cast(str, self.observation_type),
+            cast(str, self.content),
+            cast(observation.Evidence, self.evidence),
+        )
+        semantic_chunk = extract.DataChunk(
+            data=[semantic_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+            modality="semantic",
+        )
+
+        # 2. Temporal chunk - time-aware representation
+        temporal_text = observation.generate_temporal_text(
+            cast(str, self.subject),
+            cast(str, self.content),
+            cast(float, self.confidence),
+            cast(datetime, self.inserted_at),
+        )
+        temporal_chunk = extract.DataChunk(
+            data=[temporal_text],
+            metadata=merge_metadata(metadata, {"embedding_type": "temporal"}),
+            modality="temporal",
+        )
+
+        others = [
+            self._make_chunk(
+                extract.DataChunk(
+                    data=[i],
+                    metadata=merge_metadata(metadata, {"embedding_type": "semantic"}),
+                    modality="semantic",
+                )
+            )
+            for i in [
+                self.content,
+                self.subject,
+                self.observation_type,
+                self.evidence.get("quote", ""),
+            ]
+            if i
+        ]
+
+        # TODO: Add more embedding dimensions here:
+        # 3. Epistemic chunk - belief structure focused
+        # epistemic_text = self._generate_epistemic_text()
+        # chunks.append(extract.DataChunk(
+        #     data=[epistemic_text],
+        #     metadata={**base_metadata, "embedding_type": "epistemic"},
+        #     collection_name="observations_epistemic"
+        # ))
+        #
+        # 4. Emotional chunk - emotional context focused
+        # emotional_text = self._generate_emotional_text()
+        # chunks.append(extract.DataChunk(
+        #     data=[emotional_text],
+        #     metadata={**base_metadata, "embedding_type": "emotional"},
+        #     collection_name="observations_emotional"
+        # ))
+        #
+        # 5. Relational chunk - connection patterns focused
+        # relational_text = self._generate_relational_text()
+        # chunks.append(extract.DataChunk(
+        #     data=[relational_text],
+        #     metadata={**base_metadata, "embedding_type": "relational"},
+        #     collection_name="observations_relational"
+        # ))
+
+        return [
+            self._make_chunk(semantic_chunk),
+            self._make_chunk(temporal_chunk),
+        ] + others
src/memory/common/db/models/sources.py
Normal file
122
src/memory/common/db/models/sources.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Database models for the knowledge base system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
ARRAY,
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Book(Base):
|
||||||
|
"""Book-level metadata table"""
|
||||||
|
|
||||||
|
__tablename__ = "book"
|
||||||
|
|
||||||
|
id = Column(BigInteger, primary_key=True)
|
||||||
|
isbn = Column(Text, unique=True)
|
||||||
|
title = Column(Text, nullable=False)
|
||||||
|
author = Column(Text)
|
||||||
|
publisher = Column(Text)
|
||||||
|
published = Column(DateTime(timezone=True))
|
||||||
|
language = Column(Text)
|
||||||
|
edition = Column(Text)
|
||||||
|
series = Column(Text)
|
||||||
|
series_number = Column(Integer)
|
||||||
|
total_pages = Column(Integer)
|
||||||
|
file_path = Column(Text)
|
||||||
|
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
# Metadata from ebook parser
|
||||||
|
book_metadata = Column(JSONB, name="metadata")
|
||||||
|
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("book_isbn_idx", "isbn"),
|
||||||
|
Index("book_author_idx", "author"),
|
||||||
|
Index("book_title_idx", "title"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_payload(self) -> dict:
|
||||||
|
return {
|
||||||
|
**super().as_payload(),
|
||||||
|
"isbn": self.isbn,
|
||||||
|
"title": self.title,
|
||||||
|
"author": self.author,
|
||||||
|
"publisher": self.publisher,
|
||||||
|
"published": self.published,
|
||||||
|
"language": self.language,
|
||||||
|
"edition": self.edition,
|
||||||
|
"series": self.series,
|
||||||
|
"series_number": self.series_number,
|
||||||
|
} | (cast(dict, self.book_metadata) or {})
|
||||||
|
|
||||||
|
|
||||||
|
class ArticleFeed(Base):
|
||||||
|
__tablename__ = "article_feeds"
|
||||||
|
|
||||||
|
id = Column(BigInteger, primary_key=True)
|
||||||
|
url = Column(Text, nullable=False, unique=True)
|
||||||
|
title = Column(Text)
|
||||||
|
description = Column(Text)
|
||||||
|
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
check_interval = Column(
|
||||||
|
Integer, nullable=False, server_default="60", doc="Minutes between checks"
|
||||||
|
)
|
||||||
|
last_checked_at = Column(DateTime(timezone=True))
|
||||||
|
active = Column(Boolean, nullable=False, server_default="true")
|
||||||
|
created_at = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index("article_feeds_active_idx", "active", "last_checked_at"),
|
||||||
|
Index("article_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailAccount(Base):
|
||||||
|
__tablename__ = "email_accounts"
|
||||||
|
|
||||||
|
id = Column(BigInteger, primary_key=True)
|
||||||
|
name = Column(Text, nullable=False)
|
||||||
|
email_address = Column(Text, nullable=False, unique=True)
|
||||||
|
imap_server = Column(Text, nullable=False)
|
||||||
|
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||||
|
username = Column(Text, nullable=False)
|
||||||
|
password = Column(Text, nullable=False)
|
||||||
|
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||||
|
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
last_sync_at = Column(DateTime(timezone=True))
|
||||||
|
active = Column(Boolean, nullable=False, server_default="true")
|
||||||
|
created_at = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||||
|
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||||
|
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||||
|
)
|
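
As an aside, a minimal sketch of how the check_interval and last_checked_at columns might drive feed polling. The select_due_feeds helper is illustrative only and not part of this commit:

# Illustrative helper (not part of this commit): pick active feeds whose
# check_interval has elapsed since the last check.
from datetime import datetime, timedelta, timezone


def select_due_feeds(session) -> list[ArticleFeed]:
    now = datetime.now(timezone.utc)
    feeds = session.query(ArticleFeed).filter(ArticleFeed.active.is_(True)).all()
    return [
        feed
        for feed in feeds
        if feed.last_checked_at is None
        or feed.last_checked_at + timedelta(minutes=feed.check_interval) <= now
    ]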
@ -16,7 +16,7 @ logger = logging.getLogger(__name__)
 
 
 def embed_chunks(
-    chunks: list[str] | list[list[extract.MulitmodalChunk]],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
 ) -> list[Vector]:
@ -24,52 +24,50 @ def embed_chunks(
     vo = voyageai.Client()  # type: ignore
     if model == settings.MIXED_EMBEDDING_MODEL:
         return vo.multimodal_embed(
-            chunks,  # type: ignore
+            chunks,
             model=model,
             input_type=input_type,
         ).embeddings
-    return vo.embed(chunks, model=model, input_type=input_type).embeddings  # type: ignore
+    texts = ["\n".join(i for i in c if isinstance(i, str)) for c in chunks]
+    return cast(
+        list[Vector], vo.embed(texts, model=model, input_type=input_type).embeddings
+    )
+
+
+def break_chunk(
+    chunk: list[extract.MulitmodalChunk], chunk_size: int = DEFAULT_CHUNK_TOKENS
+) -> list[extract.MulitmodalChunk]:
+    result = []
+    for c in chunk:
+        if isinstance(c, str):
+            result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
+        else:
+            result.append(chunk)
+    return result
 
 
 def embed_text(
-    texts: list[str],
+    chunks: list[list[extract.MulitmodalChunk]],
     model: str = settings.TEXT_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    chunks = [
-        c
-        for text in texts
-        if isinstance(text, str)
-        for c in chunk_text(text, chunk_size, OVERLAP_TOKENS)
-        if c.strip()
-    ]
-    if not chunks:
+    chunked_chunks = [break_chunk(chunk, chunk_size) for chunk in chunks]
+    if not any(chunked_chunks):
         return []
 
-    try:
-        return embed_chunks(chunks, model, input_type)
-    except voyageai.error.InvalidRequestError as e:  # type: ignore
-        logger.error(f"Error embedding text: {e}")
-        logger.debug(f"Text: {texts}")
-        raise
+    return embed_chunks(chunked_chunks, model, input_type)
 
 
 def embed_mixed(
-    items: list[extract.MulitmodalChunk],
+    items: list[list[extract.MulitmodalChunk]],
     model: str = settings.MIXED_EMBEDDING_MODEL,
     input_type: Literal["document", "query"] = "document",
     chunk_size: int = DEFAULT_CHUNK_TOKENS,
 ) -> list[Vector]:
-    def to_chunks(item: extract.MulitmodalChunk) -> Iterable[extract.MulitmodalChunk]:
-        if isinstance(item, str):
-            return [
-                c for c in chunk_text(item, chunk_size, OVERLAP_TOKENS) if c.strip()
-            ]
-        return [item]
-
-    chunks = [c for item in items for c in to_chunks(item)]
-    return embed_chunks([chunks], model, input_type)
+    chunked_chunks = [break_chunk(item, chunk_size) for item in items]
+    return embed_chunks(chunked_chunks, model, input_type)
 
 
 def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
@ -87,6 +85,9 @ def embed_by_model(chunks: list[Chunk], model: str) -> list[Chunk]:
 
 def embed_source_item(item: SourceItem) -> list[Chunk]:
     chunks = list(item.data_chunks())
+    logger.error(
+        f"Embedding source item: {item.id} - {[(c.embedding_model, c.collection_name, c.chunks) for c in chunks]}"
+    )
     if not chunks:
         return []
 
@ -21,6 +21,7 @ class DataChunk:
     data: Sequence[MulitmodalChunk]
     metadata: dict[str, Any] = field(default_factory=dict)
     mime_type: str = "text/plain"
+    modality: str | None = None
 
 
 @contextmanager
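
A short, hedged sketch of the new calling convention introduced above: embed_text and embed_mixed now take a list of chunks, where each chunk is itself a list of multimodal parts. Exact splits depend on DEFAULT_CHUNK_TOKENS and OVERLAP_TOKENS:

# Hedged sketch, not part of the diff: each outer list is one logical chunk.
chunks = [
    ["First document text"],
    ["Second document text", "with a second string part"],
]
vectors = embed_text(chunks, model=settings.TEXT_EMBEDDING_MODEL)
# One vector is returned per outer chunk; break_chunk() re-splits any string
# that exceeds the token budget before the Voyage client is invoked, and
# embed_chunks() joins the string parts of each chunk back into one text.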
87
src/memory/common/formatters/observation.py
Normal file
@ -0,0 +1,87 @@
from datetime import datetime
from typing import TypedDict


class Evidence(TypedDict):
    quote: str
    context: str


def generate_semantic_text(
    subject: str, observation_type: str, content: str, evidence: Evidence | None = None
) -> str:
    """Generate text optimized for semantic similarity search."""
    parts = [
        f"Subject: {subject}",
        f"Type: {observation_type}",
        f"Observation: {content}",
    ]

    if not evidence or not isinstance(evidence, dict):
        return " | ".join(parts)

    if "quote" in evidence:
        parts.append(f"Quote: {evidence['quote']}")
    if "context" in evidence:
        parts.append(f"Context: {evidence['context']}")

    return " | ".join(parts)


def generate_temporal_text(
    subject: str,
    content: str,
    confidence: float,
    created_at: datetime,
) -> str:
    """Generate text with temporal context for time-pattern search."""
    # Add temporal markers
    time_of_day = created_at.strftime("%H:%M")
    day_of_week = created_at.strftime("%A")

    # Categorize time periods
    hour = created_at.hour
    if 5 <= hour < 12:
        time_period = "morning"
    elif 12 <= hour < 17:
        time_period = "afternoon"
    elif 17 <= hour < 22:
        time_period = "evening"
    else:
        time_period = "late_night"

    parts = [
        f"Time: {time_of_day} on {day_of_week} ({time_period})",
        f"Subject: {subject}",
        f"Observation: {content}",
    ]
    if confidence:
        parts.append(f"Confidence: {confidence}")

    return " | ".join(parts)


# TODO: Add more embedding dimensions here:
# 3. Epistemic chunk - belief structure focused
# epistemic_text = self._generate_epistemic_text()
# chunks.append(extract.DataChunk(
#     data=[epistemic_text],
#     metadata={**base_metadata, "embedding_type": "epistemic"},
#     collection_name="observations_epistemic"
# ))
#
# 4. Emotional chunk - emotional context focused
# emotional_text = self._generate_emotional_text()
# chunks.append(extract.DataChunk(
#     data=[emotional_text],
#     metadata={**base_metadata, "embedding_type": "emotional"},
#     collection_name="observations_emotional"
# ))
#
# 5. Relational chunk - connection patterns focused
# relational_text = self._generate_relational_text()
# chunks.append(extract.DataChunk(
#     data=[relational_text],
#     metadata={**base_metadata, "embedding_type": "relational"},
#     collection_name="observations_relational"
# ))
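
For reference, a short sketch of what these formatters produce; the expected strings mirror the test expectations added later in this diff:

from datetime import datetime

from memory.common.formatters.observation import (
    generate_semantic_text,
    generate_temporal_text,
)

semantic = generate_semantic_text(
    subject="programming preferences",
    observation_type="preference",
    content="User prefers Python over JavaScript",
    evidence={"quote": "I really like Python", "context": "discussion about languages"},
)
# "Subject: programming preferences | Type: preference | Observation: User prefers
#  Python over JavaScript | Quote: I really like Python | Context: discussion about languages"

temporal = generate_temporal_text(
    subject="programming preferences",
    content="User prefers Python over JavaScript",
    confidence=0.9,
    created_at=datetime(2023, 1, 1, 12, 0, 0),
)
# "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences |
#  Observation: User prefers Python over JavaScript | Confidence: 0.9"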
@ -9,6 +9,8 @ logger = logging.getLogger(__name__)
 TAGS_PROMPT = """
 The following text is already concise. Please identify 3-5 relevant tags that capture the main topics or themes.
 
+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
 "summary": "{summary}",
@ -23,6 +25,8 @ SUMMARY_PROMPT = """
 Please summarize the following text into approximately {target_tokens} tokens ({target_chars} characters).
 Also provide 3-5 relevant tags that capture the main topics or themes.
 
+Tags should be lowercase and use hyphens instead of spaces, e.g. "machine-learning" instead of "Machine Learning".
+
 Return your response as JSON with this format:
 {{
 "summary": "your summary here",
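
A minimal sketch of consuming such a response, assuming the model replies with the JSON shape the prompts request; the parse_summary_response helper name is illustrative, not part of the codebase:

import json


def parse_summary_response(raw: str) -> tuple[str, list[str]]:
    """Parse the {"summary": ..., "tags": [...]} payload the prompts ask for."""
    data = json.loads(raw)
    summary = data.get("summary", "")
    # Normalise tags the same way the prompt requests: lowercase, hyphenated.
    tags = [t.strip().lower().replace(" ", "-") for t in data.get("tags", [])]
    return summary, tags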
@ -1,62 +0,0 @@
import argparse
import logging
from typing import Any
from fastapi import UploadFile
import httpx
from mcp.server.fastmcp import FastMCP

SERVER = "http://localhost:8000"


logger = logging.getLogger(__name__)
mcp = FastMCP("memory")


async def make_request(
    path: str,
    method: str,
    data: dict | None = None,
    json: dict | None = None,
    files: list[UploadFile] | None = None,
) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        return await client.request(
            method,
            f"{SERVER}/{path}",
            data=data,
            json=json,
            files=files,  # type: ignore
        )


async def post_data(path: str, data: dict | None = None) -> httpx.Response:
    return await make_request(path, "POST", data=data)


@mcp.tool()
async def search(
    query: str, previews: bool = False, modalities: list[str] = [], limit: int = 10
) -> list[dict[str, Any]]:
    logger.error(f"Searching for {query}")
    resp = await post_data(
        "search",
        {
            "query": query,
            "previews": previews,
            "modalities": modalities,
            "limit": limit,
        },
    )

    return resp.json()


if __name__ == "__main__":
    # Initialize and run the server
    args = argparse.ArgumentParser()
    args.add_argument("--server", type=str)
    args = args.parse_args()

    SERVER = args.server

    mcp.run(transport=args.transport)
@ -6,13 +6,14 @ the complete workflow: existence checking, content hashing, embedding generation
 vector storage, and result tracking.
 """
 
+from collections import defaultdict
 import hashlib
 import traceback
 import logging
 from typing import Any, Callable, Iterable, Sequence, cast
 
 from memory.common import embedding, qdrant
-from memory.common.db.models import SourceItem
+from memory.common.db.models import SourceItem, Chunk
 
 logger = logging.getLogger(__name__)
 
@ -103,7 +104,19 @ def embed_source_item(source_item: SourceItem) -> int:
     return 0
 
 
-def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
+def by_collection(chunks: Sequence[Chunk]) -> dict[str, dict[str, Any]]:
+    collections: dict[str, dict[str, list[Any]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+    for chunk in chunks:
+        collection = collections[cast(str, chunk.collection_name)]
+        collection["ids"].append(chunk.id)
+        collection["vectors"].append(chunk.vector)
+        collection["payloads"].append(chunk.item_metadata)
+    return collections
+
+
+def push_to_qdrant(source_items: Sequence[SourceItem]):
     """
     Push embeddings to Qdrant vector database.
 
@ -135,17 +148,16 @ def push_to_qdrant(source_items: Sequence[SourceItem], collection_name: str):
         return
 
     try:
-        vector_ids = [str(chunk.id) for chunk in all_chunks]
-        vectors = [chunk.vector for chunk in all_chunks]
-        payloads = [chunk.item_metadata for chunk in all_chunks]
-
-        qdrant.upsert_vectors(
-            client=qdrant.get_qdrant_client(),
-            collection_name=collection_name,
-            ids=vector_ids,
-            vectors=vectors,
-            payloads=payloads,
-        )
+        client = qdrant.get_qdrant_client()
+        collections = by_collection(all_chunks)
+        for collection_name, collection in collections.items():
+            qdrant.upsert_vectors(
+                client=client,
+                collection_name=collection_name,
+                ids=collection["ids"],
+                vectors=collection["vectors"],
+                payloads=collection["payloads"],
+            )
 
         for item in items_to_process:
             item.embed_status = "STORED"  # type: ignore
@ -222,7 +234,7 @ def process_content_item(item: SourceItem, session) -> dict[str, Any]:
     return create_task_result(item, status, content_length=getattr(item, "size", 0))
 
     try:
-        push_to_qdrant([item], cast(str, item.modality))
+        push_to_qdrant([item])
         status = "processed"
         item.embed_status = "STORED"  # type: ignore
         logger.info(
@ -185,7 +185,7 @ def sync_book(file_path: str, tags: Iterable[str] = []) -> dict:
             f"Embedded section: {section.section_title} - {section.content[:100]}"
         )
     logger.info("Pushing to Qdrant")
-    push_to_qdrant(all_sections, "book")
+    push_to_qdrant(all_sections)
     logger.info("Committing session")
 
     session.commit()
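
To illustrate the new grouping step, a small hedged sketch of the shape by_collection produces; the chunk objects here are stand-ins for ORM Chunk rows, not real database records:

# Illustrative only: fake chunk objects standing in for ORM Chunk rows.
from types import SimpleNamespace

chunks = [
    SimpleNamespace(id=1, vector=[0.1, 0.2], item_metadata={"a": 1}, collection_name="semantic"),
    SimpleNamespace(id=2, vector=[0.3, 0.4], item_metadata={"b": 2}, collection_name="temporal"),
]

grouped = by_collection(chunks)
# Roughly:
# {
#   "semantic": {"ids": [1], "vectors": [[0.1, 0.2]], "payloads": [{"a": 1}]},
#   "temporal": {"ids": [2], "vectors": [[0.3, 0.4]], "payloads": [{"b": 2}]},
# }
# push_to_qdrant() then issues one upsert_vectors() call per collection.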
@ -1,4 +1,3 @@
-from memory.common.db.models import SourceItem
 from sqlalchemy.orm import Session
 from unittest.mock import patch, Mock
 from typing import cast
@ -6,17 +5,20 @ import pytest
 from PIL import Image
 from datetime import datetime
 from memory.common import settings, chunker, extract
-from memory.common.db.models import (
+from memory.common.db.models.sources import Book
+from memory.common.db.models.source_items import (
     Chunk,
-    clean_filename,
-    image_filenames,
-    add_pics,
     MailMessage,
     EmailAttachment,
     BookSection,
     BlogPost,
-    Book,
+)
+from memory.common.db.models.source_item import (
+    SourceItem,
+    image_filenames,
+    add_pics,
     merge_metadata,
+    clean_filename,
 )
 
 
@ -585,288 +587,6 @@ def test_chunk_constraint_validation(
|
|||||||
assert chunk.id is not None
|
assert chunk.id is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"modality,expected_modality",
|
|
||||||
[
|
|
||||||
(None, "email"), # Default case
|
|
||||||
("custom", "custom"), # Override case
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_mail_message_modality(modality, expected_modality):
|
|
||||||
"""Test MailMessage modality setting"""
|
|
||||||
kwargs = {"sha256": b"test", "content": "test"}
|
|
||||||
if modality is not None:
|
|
||||||
kwargs["modality"] = modality
|
|
||||||
|
|
||||||
mail_message = MailMessage(**kwargs)
|
|
||||||
# The __init__ method should set the correct modality
|
|
||||||
assert hasattr(mail_message, "modality")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sender,folder,expected_path",
|
|
||||||
[
|
|
||||||
("user@example.com", "INBOX", "user_example_com/INBOX"),
|
|
||||||
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
|
|
||||||
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
|
|
||||||
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_mail_message_attachments_path(sender, folder, expected_path):
|
|
||||||
"""Test MailMessage.attachments_path property"""
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test", content="test", sender=sender, folder=folder
|
|
||||||
)
|
|
||||||
|
|
||||||
result = mail_message.attachments_path
|
|
||||||
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"filename,expected",
|
|
||||||
[
|
|
||||||
("document.pdf", "document.pdf"),
|
|
||||||
("file with spaces.txt", "file_with_spaces.txt"),
|
|
||||||
("file@#$%^&*().doc", "file.doc"),
|
|
||||||
("no-extension", "no_extension"),
|
|
||||||
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_mail_message_safe_filename(tmp_path, filename, expected):
|
|
||||||
"""Test MailMessage.safe_filename method"""
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
|
|
||||||
)
|
|
||||||
|
|
||||||
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
|
|
||||||
assert mail_message.safe_filename(filename) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sent_at,expected_date",
|
|
||||||
[
|
|
||||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
|
||||||
(None, None),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_mail_message_as_payload(sent_at, expected_date):
|
|
||||||
"""Test MailMessage.as_payload method"""
|
|
||||||
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test",
|
|
||||||
content="test",
|
|
||||||
message_id="<test@example.com>",
|
|
||||||
subject="Test Subject",
|
|
||||||
sender="sender@example.com",
|
|
||||||
recipients=["recipient1@example.com", "recipient2@example.com"],
|
|
||||||
folder="INBOX",
|
|
||||||
sent_at=sent_at,
|
|
||||||
tags=["tag1", "tag2"],
|
|
||||||
size=1024,
|
|
||||||
)
|
|
||||||
# Manually set id for testing
|
|
||||||
object.__setattr__(mail_message, "id", 123)
|
|
||||||
|
|
||||||
payload = mail_message.as_payload()
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"source_id": 123,
|
|
||||||
"size": 1024,
|
|
||||||
"message_id": "<test@example.com>",
|
|
||||||
"subject": "Test Subject",
|
|
||||||
"sender": "sender@example.com",
|
|
||||||
"recipients": ["recipient1@example.com", "recipient2@example.com"],
|
|
||||||
"folder": "INBOX",
|
|
||||||
"tags": [
|
|
||||||
"tag1",
|
|
||||||
"tag2",
|
|
||||||
"sender@example.com",
|
|
||||||
"recipient1@example.com",
|
|
||||||
"recipient2@example.com",
|
|
||||||
],
|
|
||||||
"date": expected_date,
|
|
||||||
}
|
|
||||||
assert payload == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_mail_message_parsed_content():
|
|
||||||
"""Test MailMessage.parsed_content property with actual email parsing"""
|
|
||||||
# Use a simple email format that the parser can handle
|
|
||||||
email_content = """From: sender@example.com
|
|
||||||
To: recipient@example.com
|
|
||||||
Subject: Test Subject
|
|
||||||
|
|
||||||
Test Body Content"""
|
|
||||||
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = mail_message.parsed_content
|
|
||||||
|
|
||||||
# Just test that it returns a dict-like object
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "body" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_mail_message_body_property():
|
|
||||||
"""Test MailMessage.body property with actual email parsing"""
|
|
||||||
email_content = """From: sender@example.com
|
|
||||||
To: recipient@example.com
|
|
||||||
Subject: Test Subject
|
|
||||||
|
|
||||||
Test Body Content"""
|
|
||||||
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mail_message.body == "Test Body Content"
|
|
||||||
|
|
||||||
|
|
||||||
def test_mail_message_display_contents():
|
|
||||||
"""Test MailMessage.display_contents property with actual email parsing"""
|
|
||||||
email_content = """From: sender@example.com
|
|
||||||
To: recipient@example.com
|
|
||||||
Subject: Test Subject
|
|
||||||
|
|
||||||
Test Body Content"""
|
|
||||||
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
|
||||||
)
|
|
||||||
|
|
||||||
expected = (
|
|
||||||
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
|
|
||||||
)
|
|
||||||
assert mail_message.display_contents == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"created_at,expected_date",
|
|
||||||
[
|
|
||||||
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
|
||||||
(None, None),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_email_attachment_as_payload(created_at, expected_date):
|
|
||||||
"""Test EmailAttachment.as_payload method"""
|
|
||||||
attachment = EmailAttachment(
|
|
||||||
sha256=b"test",
|
|
||||||
filename="document.pdf",
|
|
||||||
mime_type="application/pdf",
|
|
||||||
size=1024,
|
|
||||||
mail_message_id=123,
|
|
||||||
created_at=created_at,
|
|
||||||
tags=["pdf", "document"],
|
|
||||||
)
|
|
||||||
# Manually set id for testing
|
|
||||||
object.__setattr__(attachment, "id", 456)
|
|
||||||
|
|
||||||
payload = attachment.as_payload()
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"source_id": 456,
|
|
||||||
"filename": "document.pdf",
|
|
||||||
"content_type": "application/pdf",
|
|
||||||
"size": 1024,
|
|
||||||
"created_at": expected_date,
|
|
||||||
"mail_message_id": 123,
|
|
||||||
"tags": ["pdf", "document"],
|
|
||||||
}
|
|
||||||
assert payload == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"has_filename,content_source,expected_content",
|
|
||||||
[
|
|
||||||
(True, "file", b"test file content"),
|
|
||||||
(False, "content", "attachment content"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@patch("memory.common.extract.extract_data_chunks")
|
|
||||||
def test_email_attachment_data_chunks(
|
|
||||||
mock_extract, has_filename, content_source, expected_content, tmp_path
|
|
||||||
):
|
|
||||||
"""Test EmailAttachment.data_chunks method"""
|
|
||||||
from memory.common.extract import DataChunk
|
|
||||||
|
|
||||||
mock_extract.return_value = [
|
|
||||||
DataChunk(data=["extracted text"], metadata={"source": content_source})
|
|
||||||
]
|
|
||||||
|
|
||||||
if has_filename:
|
|
||||||
# Create a test file
|
|
||||||
test_file = tmp_path / "test.txt"
|
|
||||||
test_file.write_bytes(b"test file content")
|
|
||||||
attachment = EmailAttachment(
|
|
||||||
sha256=b"test",
|
|
||||||
filename=str(test_file),
|
|
||||||
mime_type="text/plain",
|
|
||||||
mail_message_id=123,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attachment = EmailAttachment(
|
|
||||||
sha256=b"test",
|
|
||||||
content="attachment content",
|
|
||||||
filename=None,
|
|
||||||
mime_type="text/plain",
|
|
||||||
mail_message_id=123,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock _make_chunk to return a simple chunk
|
|
||||||
mock_chunk = Mock()
|
|
||||||
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
|
|
||||||
result = attachment.data_chunks({"extra": "metadata"})
|
|
||||||
|
|
||||||
# Verify the method calls
|
|
||||||
mock_extract.assert_called_once_with("text/plain", expected_content)
|
|
||||||
mock_make.assert_called_once_with(
|
|
||||||
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
|
|
||||||
{"extra": "metadata"},
|
|
||||||
)
|
|
||||||
assert result == [mock_chunk]
|
|
||||||
|
|
||||||
|
|
||||||
def test_email_attachment_cascade_delete(db_session: Session):
|
|
||||||
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
|
|
||||||
mail_message = MailMessage(
|
|
||||||
sha256=b"test_email",
|
|
||||||
content="test email",
|
|
||||||
message_id="<test@example.com>",
|
|
||||||
subject="Test",
|
|
||||||
sender="sender@example.com",
|
|
||||||
recipients=["recipient@example.com"],
|
|
||||||
folder="INBOX",
|
|
||||||
)
|
|
||||||
db_session.add(mail_message)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
attachment = EmailAttachment(
|
|
||||||
sha256=b"test_attachment",
|
|
||||||
content="attachment content",
|
|
||||||
mail_message=mail_message,
|
|
||||||
filename="test.txt",
|
|
||||||
mime_type="text/plain",
|
|
||||||
size=100,
|
|
||||||
modality="attachment", # Set modality explicitly
|
|
||||||
)
|
|
||||||
db_session.add(attachment)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
attachment_id = attachment.id
|
|
||||||
|
|
||||||
# Delete the mail message
|
|
||||||
db_session.delete(mail_message)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify the attachment was also deleted
|
|
||||||
deleted_attachment = (
|
|
||||||
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
|
|
||||||
)
|
|
||||||
assert deleted_attachment is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_subclass_deletion_cascades_to_source_item(db_session: Session):
|
def test_subclass_deletion_cascades_to_source_item(db_session: Session):
|
||||||
mail_message = MailMessage(
|
mail_message = MailMessage(
|
||||||
sha256=b"test_email_cascade",
|
sha256=b"test_email_cascade",
|
||||||
@ -926,173 +646,3 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
|
|||||||
# Verify both the MailMessage and SourceItem records are deleted
|
# Verify both the MailMessage and SourceItem records are deleted
|
||||||
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
||||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"pages,expected_chunks",
|
|
||||||
[
|
|
||||||
# No pages
|
|
||||||
([], []),
|
|
||||||
# Single page
|
|
||||||
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
|
|
||||||
# Multiple pages
|
|
||||||
(
|
|
||||||
["Page 1", "Page 2", "Page 3"],
|
|
||||||
[
|
|
||||||
(
|
|
||||||
"Page 1\n\nPage 2\n\nPage 3",
|
|
||||||
{"type": "section", "tags": {"tag1", "tag2"}},
|
|
||||||
),
|
|
||||||
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
# Empty/whitespace pages filtered out
|
|
||||||
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
|
|
||||||
# All empty - no chunks created
|
|
||||||
(["", " ", " "], []),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_book_section_data_chunks(pages, expected_chunks):
|
|
||||||
"""Test BookSection.data_chunks with various page combinations"""
|
|
||||||
content = "\n\n".join(pages).strip()
|
|
||||||
book_section = BookSection(
|
|
||||||
sha256=b"test_section",
|
|
||||||
content=content,
|
|
||||||
modality="book",
|
|
||||||
book_id=1,
|
|
||||||
start_page=10,
|
|
||||||
end_page=10 + len(pages),
|
|
||||||
pages=pages,
|
|
||||||
book=Book(id=1, title="Test Book", author="Test Author"),
|
|
||||||
)
|
|
||||||
|
|
||||||
chunks = book_section.data_chunks()
|
|
||||||
expected = [
|
|
||||||
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
|
|
||||||
]
|
|
||||||
assert [(c.content, c.item_metadata) for c in chunks] == expected
|
|
||||||
for c in chunks:
|
|
||||||
assert cast(list, c.file_paths) == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"content,expected",
|
|
||||||
[
|
|
||||||
("", []),
|
|
||||||
(
|
|
||||||
"Short content",
|
|
||||||
[
|
|
||||||
extract.DataChunk(
|
|
||||||
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
|
|
||||||
[
|
|
||||||
extract.DataChunk(
|
|
||||||
data=[
|
|
||||||
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
|
||||||
],
|
|
||||||
metadata={"tags": ["tag1", "tag2"]},
|
|
||||||
),
|
|
||||||
extract.DataChunk(
|
|
||||||
data=["This is a very long piece of content that"],
|
|
||||||
metadata={"tags": ["tag1", "tag2"]},
|
|
||||||
),
|
|
||||||
extract.DataChunk(
|
|
||||||
data=["should be chunked into multiple pieces when"],
|
|
||||||
metadata={"tags": ["tag1", "tag2"]},
|
|
||||||
),
|
|
||||||
extract.DataChunk(
|
|
||||||
data=["processed."],
|
|
||||||
metadata={"tags": ["tag1", "tag2"]},
|
|
||||||
),
|
|
||||||
extract.DataChunk(
|
|
||||||
data=["test"],
|
|
||||||
metadata={"tags": ["tag1", "tag2"]},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
|
|
||||||
default_chunk_size(10)
|
|
||||||
blog_post = BlogPost(
|
|
||||||
sha256=b"test_blog",
|
|
||||||
content=content,
|
|
||||||
modality="blog",
|
|
||||||
url="https://example.com/post",
|
|
||||||
images=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
|
||||||
assert blog_post._chunk_contents() == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_blog_post_chunk_contents_with_images(tmp_path):
|
|
||||||
"""Test BlogPost._chunk_contents with images"""
|
|
||||||
# Create test image files
|
|
||||||
img1_path = tmp_path / "img1.jpg"
|
|
||||||
img2_path = tmp_path / "img2.jpg"
|
|
||||||
for img_path in [img1_path, img2_path]:
|
|
||||||
img = Image.new("RGB", (10, 10), color="red")
|
|
||||||
img.save(img_path)
|
|
||||||
|
|
||||||
blog_post = BlogPost(
|
|
||||||
sha256=b"test_blog",
|
|
||||||
content="Content with images",
|
|
||||||
modality="blog",
|
|
||||||
url="https://example.com/post",
|
|
||||||
images=[str(img1_path), str(img2_path)],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = blog_post._chunk_contents()
|
|
||||||
result = [
|
|
||||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
|
||||||
for c in result
|
|
||||||
]
|
|
||||||
assert result == [
|
|
||||||
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
|
|
||||||
default_chunk_size(10)
|
|
||||||
img1_path = tmp_path / "img1.jpg"
|
|
||||||
img2_path = tmp_path / "img2.jpg"
|
|
||||||
for img_path in [img1_path, img2_path]:
|
|
||||||
img = Image.new("RGB", (10, 10), color="red")
|
|
||||||
img.save(img_path)
|
|
||||||
|
|
||||||
blog_post = BlogPost(
|
|
||||||
sha256=b"test_blog",
|
|
||||||
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
|
||||||
modality="blog",
|
|
||||||
url="https://example.com/post",
|
|
||||||
images=[str(img1_path), str(img2_path)],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
|
||||||
result = blog_post._chunk_contents()
|
|
||||||
|
|
||||||
result = [
|
|
||||||
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
|
||||||
for c in result
|
|
||||||
]
|
|
||||||
assert result == [
|
|
||||||
[
|
|
||||||
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
|
||||||
img1_path.as_posix(),
|
|
||||||
img2_path.as_posix(),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
f"First picture is here: {img1_path.as_posix()}",
|
|
||||||
img1_path.as_posix(),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
f"Second picture is here: {img2_path.as_posix()}",
|
|
||||||
img2_path.as_posix(),
|
|
||||||
],
|
|
||||||
["test"],
|
|
||||||
]
|
|
607
tests/memory/common/db/models/test_source_items.py
Normal file
@ -0,0 +1,607 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
from typing import cast
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
from memory.common import settings, chunker, extract
|
||||||
|
from memory.common.db.models.sources import Book
|
||||||
|
from memory.common.db.models.source_items import (
|
||||||
|
MailMessage,
|
||||||
|
EmailAttachment,
|
||||||
|
BookSection,
|
||||||
|
BlogPost,
|
||||||
|
AgentObservation,
|
||||||
|
)
|
||||||
|
from memory.common.db.models.source_item import merge_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def default_chunk_size():
|
||||||
|
chunk_length = chunker.DEFAULT_CHUNK_TOKENS
|
||||||
|
real_chunker = chunker.chunk_text
|
||||||
|
|
||||||
|
def chunk_text(text: str, max_tokens: int = 0):
|
||||||
|
max_tokens = max_tokens or chunk_length
|
||||||
|
return real_chunker(text, max_tokens=max_tokens)
|
||||||
|
|
||||||
|
def set_size(new_size: int):
|
||||||
|
nonlocal chunk_length
|
||||||
|
chunk_length = new_size
|
||||||
|
|
||||||
|
with patch.object(chunker, "chunk_text", chunk_text):
|
||||||
|
yield set_size
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"modality,expected_modality",
|
||||||
|
[
|
||||||
|
(None, "email"), # Default case
|
||||||
|
("custom", "custom"), # Override case
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mail_message_modality(modality, expected_modality):
|
||||||
|
"""Test MailMessage modality setting"""
|
||||||
|
kwargs = {"sha256": b"test", "content": "test"}
|
||||||
|
if modality is not None:
|
||||||
|
kwargs["modality"] = modality
|
||||||
|
|
||||||
|
mail_message = MailMessage(**kwargs)
|
||||||
|
# The __init__ method should set the correct modality
|
||||||
|
assert hasattr(mail_message, "modality")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sender,folder,expected_path",
|
||||||
|
[
|
||||||
|
("user@example.com", "INBOX", "user_example_com/INBOX"),
|
||||||
|
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
|
||||||
|
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
|
||||||
|
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mail_message_attachments_path(sender, folder, expected_path):
|
||||||
|
"""Test MailMessage.attachments_path property"""
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test", content="test", sender=sender, folder=folder
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mail_message.attachments_path
|
||||||
|
assert str(result) == f"{settings.FILE_STORAGE_DIR}/emails/{expected_path}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"filename,expected",
|
||||||
|
[
|
||||||
|
("document.pdf", "document.pdf"),
|
||||||
|
("file with spaces.txt", "file_with_spaces.txt"),
|
||||||
|
("file@#$%^&*().doc", "file.doc"),
|
||||||
|
("no-extension", "no_extension"),
|
||||||
|
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mail_message_safe_filename(tmp_path, filename, expected):
|
||||||
|
"""Test MailMessage.safe_filename method"""
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = settings.FILE_STORAGE_DIR / f"emails/user_example_com/INBOX/{expected}"
|
||||||
|
assert mail_message.safe_filename(filename) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sent_at,expected_date",
|
||||||
|
[
|
||||||
|
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||||
|
(None, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_mail_message_as_payload(sent_at, expected_date):
|
||||||
|
"""Test MailMessage.as_payload method"""
|
||||||
|
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test",
|
||||||
|
content="test",
|
||||||
|
message_id="<test@example.com>",
|
||||||
|
subject="Test Subject",
|
||||||
|
sender="sender@example.com",
|
||||||
|
recipients=["recipient1@example.com", "recipient2@example.com"],
|
||||||
|
folder="INBOX",
|
||||||
|
sent_at=sent_at,
|
||||||
|
tags=["tag1", "tag2"],
|
||||||
|
size=1024,
|
||||||
|
)
|
||||||
|
# Manually set id for testing
|
||||||
|
object.__setattr__(mail_message, "id", 123)
|
||||||
|
|
||||||
|
payload = mail_message.as_payload()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"source_id": 123,
|
||||||
|
"size": 1024,
|
||||||
|
"message_id": "<test@example.com>",
|
||||||
|
"subject": "Test Subject",
|
||||||
|
"sender": "sender@example.com",
|
||||||
|
"recipients": ["recipient1@example.com", "recipient2@example.com"],
|
||||||
|
"folder": "INBOX",
|
||||||
|
"tags": [
|
||||||
|
"tag1",
|
||||||
|
"tag2",
|
||||||
|
"sender@example.com",
|
||||||
|
"recipient1@example.com",
|
||||||
|
"recipient2@example.com",
|
||||||
|
],
|
||||||
|
"date": expected_date,
|
||||||
|
}
|
||||||
|
assert payload == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_mail_message_parsed_content():
|
||||||
|
"""Test MailMessage.parsed_content property with actual email parsing"""
|
||||||
|
# Use a simple email format that the parser can handle
|
||||||
|
email_content = """From: sender@example.com
|
||||||
|
To: recipient@example.com
|
||||||
|
Subject: Test Subject
|
||||||
|
|
||||||
|
Test Body Content"""
|
||||||
|
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mail_message.parsed_content
|
||||||
|
|
||||||
|
# Just test that it returns a dict-like object
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "body" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_mail_message_body_property():
|
||||||
|
"""Test MailMessage.body property with actual email parsing"""
|
||||||
|
email_content = """From: sender@example.com
|
||||||
|
To: recipient@example.com
|
||||||
|
Subject: Test Subject
|
||||||
|
|
||||||
|
Test Body Content"""
|
||||||
|
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mail_message.body == "Test Body Content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mail_message_display_contents():
|
||||||
|
"""Test MailMessage.display_contents property with actual email parsing"""
|
||||||
|
email_content = """From: sender@example.com
|
||||||
|
To: recipient@example.com
|
||||||
|
Subject: Test Subject
|
||||||
|
|
||||||
|
Test Body Content"""
|
||||||
|
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test", content=email_content, message_id="<test@example.com>"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = (
|
||||||
|
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
|
||||||
|
)
|
||||||
|
assert mail_message.display_contents == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"created_at,expected_date",
|
||||||
|
[
|
||||||
|
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
|
||||||
|
(None, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_email_attachment_as_payload(created_at, expected_date):
|
||||||
|
"""Test EmailAttachment.as_payload method"""
|
||||||
|
attachment = EmailAttachment(
|
||||||
|
sha256=b"test",
|
||||||
|
filename="document.pdf",
|
||||||
|
mime_type="application/pdf",
|
||||||
|
size=1024,
|
||||||
|
mail_message_id=123,
|
||||||
|
created_at=created_at,
|
||||||
|
tags=["pdf", "document"],
|
||||||
|
)
|
||||||
|
# Manually set id for testing
|
||||||
|
object.__setattr__(attachment, "id", 456)
|
||||||
|
|
||||||
|
payload = attachment.as_payload()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"source_id": 456,
|
||||||
|
"filename": "document.pdf",
|
||||||
|
"content_type": "application/pdf",
|
||||||
|
"size": 1024,
|
||||||
|
"created_at": expected_date,
|
||||||
|
"mail_message_id": 123,
|
||||||
|
"tags": ["pdf", "document"],
|
||||||
|
}
|
||||||
|
assert payload == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"has_filename,content_source,expected_content",
|
||||||
|
[
|
||||||
|
(True, "file", b"test file content"),
|
||||||
|
(False, "content", "attachment content"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("memory.common.extract.extract_data_chunks")
|
||||||
|
def test_email_attachment_data_chunks(
|
||||||
|
mock_extract, has_filename, content_source, expected_content, tmp_path
|
||||||
|
):
|
||||||
|
"""Test EmailAttachment.data_chunks method"""
|
||||||
|
from memory.common.extract import DataChunk
|
||||||
|
|
||||||
|
mock_extract.return_value = [
|
||||||
|
DataChunk(data=["extracted text"], metadata={"source": content_source})
|
||||||
|
]
|
||||||
|
|
||||||
|
if has_filename:
|
||||||
|
# Create a test file
|
||||||
|
test_file = tmp_path / "test.txt"
|
||||||
|
test_file.write_bytes(b"test file content")
|
||||||
|
attachment = EmailAttachment(
|
||||||
|
sha256=b"test",
|
||||||
|
filename=str(test_file),
|
||||||
|
mime_type="text/plain",
|
||||||
|
mail_message_id=123,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attachment = EmailAttachment(
|
||||||
|
sha256=b"test",
|
||||||
|
content="attachment content",
|
||||||
|
filename=None,
|
||||||
|
mime_type="text/plain",
|
||||||
|
mail_message_id=123,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock _make_chunk to return a simple chunk
|
||||||
|
mock_chunk = Mock()
|
||||||
|
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
|
||||||
|
result = attachment.data_chunks({"extra": "metadata"})
|
||||||
|
|
||||||
|
# Verify the method calls
|
||||||
|
mock_extract.assert_called_once_with("text/plain", expected_content)
|
||||||
|
mock_make.assert_called_once_with(
|
||||||
|
extract.DataChunk(data=["extracted text"], metadata={"source": content_source}),
|
||||||
|
{"extra": "metadata"},
|
||||||
|
)
|
||||||
|
assert result == [mock_chunk]
|
||||||
|
|
||||||
|
|
||||||
|
def test_email_attachment_cascade_delete(db_session: Session):
|
||||||
|
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
|
||||||
|
mail_message = MailMessage(
|
||||||
|
sha256=b"test_email",
|
||||||
|
content="test email",
|
||||||
|
message_id="<test@example.com>",
|
||||||
|
subject="Test",
|
||||||
|
sender="sender@example.com",
|
||||||
|
recipients=["recipient@example.com"],
|
||||||
|
folder="INBOX",
|
||||||
|
)
|
||||||
|
db_session.add(mail_message)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
attachment = EmailAttachment(
|
||||||
|
sha256=b"test_attachment",
|
||||||
|
content="attachment content",
|
||||||
|
mail_message=mail_message,
|
||||||
|
filename="test.txt",
|
||||||
|
mime_type="text/plain",
|
||||||
|
size=100,
|
||||||
|
modality="attachment", # Set modality explicitly
|
||||||
|
)
|
||||||
|
db_session.add(attachment)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
attachment_id = attachment.id
|
||||||
|
|
||||||
|
# Delete the mail message
|
||||||
|
db_session.delete(mail_message)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Verify the attachment was also deleted
|
||||||
|
deleted_attachment = (
|
||||||
|
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
|
||||||
|
)
|
||||||
|
assert deleted_attachment is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pages,expected_chunks",
|
||||||
|
[
|
||||||
|
# No pages
|
||||||
|
([], []),
|
||||||
|
# Single page
|
||||||
|
(["Page 1 content"], [("Page 1 content", {"type": "page"})]),
|
||||||
|
# Multiple pages
|
||||||
|
(
|
||||||
|
["Page 1", "Page 2", "Page 3"],
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"Page 1\n\nPage 2\n\nPage 3",
|
||||||
|
{"type": "section", "tags": {"tag1", "tag2"}},
|
||||||
|
),
|
||||||
|
("test", {"type": "summary", "tags": {"tag1", "tag2"}}),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# Empty/whitespace pages filtered out
|
||||||
|
(["", " ", "Page 3"], [("Page 3", {"type": "page"})]),
|
||||||
|
# All empty - no chunks created
|
||||||
|
(["", " ", " "], []),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_book_section_data_chunks(pages, expected_chunks):
|
||||||
|
"""Test BookSection.data_chunks with various page combinations"""
|
||||||
|
content = "\n\n".join(pages).strip()
|
||||||
|
book_section = BookSection(
|
||||||
|
sha256=b"test_section",
|
||||||
|
content=content,
|
||||||
|
modality="book",
|
||||||
|
book_id=1,
|
||||||
|
start_page=10,
|
||||||
|
end_page=10 + len(pages),
|
||||||
|
pages=pages,
|
||||||
|
book=Book(id=1, title="Test Book", author="Test Author"),
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = book_section.data_chunks()
|
||||||
|
expected = [
|
||||||
|
(c, merge_metadata(book_section.as_payload(), m)) for c, m in expected_chunks
|
||||||
|
]
|
||||||
|
assert [(c.content, c.item_metadata) for c in chunks] == expected
|
||||||
|
for c in chunks:
|
||||||
|
assert cast(list, c.file_paths) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"content,expected",
|
||||||
|
[
|
||||||
|
("", []),
|
||||||
|
(
|
||||||
|
"Short content",
|
||||||
|
[
|
||||||
|
extract.DataChunk(
|
||||||
|
data=["Short content"], metadata={"tags": ["tag1", "tag2"]}
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
|
||||||
|
[
|
||||||
|
extract.DataChunk(
|
||||||
|
data=[
|
||||||
|
"This is a very long piece of content that should be chunked into multiple pieces when processed."
|
||||||
|
],
|
||||||
|
metadata={"tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
extract.DataChunk(
|
||||||
|
data=["This is a very long piece of content that"],
|
||||||
|
metadata={"tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
extract.DataChunk(
|
||||||
|
data=["should be chunked into multiple pieces when"],
|
||||||
|
metadata={"tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
extract.DataChunk(
|
||||||
|
data=["processed."],
|
||||||
|
metadata={"tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
extract.DataChunk(
|
||||||
|
data=["test"],
|
||||||
|
metadata={"tags": ["tag1", "tag2"]},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
|
||||||
|
default_chunk_size(10)
|
||||||
|
blog_post = BlogPost(
|
||||||
|
sha256=b"test_blog",
|
||||||
|
content=content,
|
||||||
|
modality="blog",
|
||||||
|
url="https://example.com/post",
|
||||||
|
images=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||||
|
assert blog_post._chunk_contents() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_blog_post_chunk_contents_with_images(tmp_path):
|
||||||
|
"""Test BlogPost._chunk_contents with images"""
|
||||||
|
# Create test image files
|
||||||
|
img1_path = tmp_path / "img1.jpg"
|
||||||
|
img2_path = tmp_path / "img2.jpg"
|
||||||
|
for img_path in [img1_path, img2_path]:
|
||||||
|
img = Image.new("RGB", (10, 10), color="red")
|
||||||
|
img.save(img_path)
|
||||||
|
|
||||||
|
blog_post = BlogPost(
|
||||||
|
sha256=b"test_blog",
|
||||||
|
content="Content with images",
|
||||||
|
modality="blog",
|
||||||
|
url="https://example.com/post",
|
||||||
|
images=[str(img1_path), str(img2_path)],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = blog_post._chunk_contents()
|
||||||
|
result = [
|
||||||
|
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||||
|
for c in result
|
||||||
|
]
|
||||||
|
assert result == [
|
||||||
|
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
|
||||||
|
default_chunk_size(10)
|
||||||
|
img1_path = tmp_path / "img1.jpg"
|
||||||
|
img2_path = tmp_path / "img2.jpg"
|
||||||
|
for img_path in [img1_path, img2_path]:
|
||||||
|
img = Image.new("RGB", (10, 10), color="red")
|
||||||
|
img.save(img_path)
|
||||||
|
|
||||||
|
blog_post = BlogPost(
|
||||||
|
sha256=b"test_blog",
|
||||||
|
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||||
|
modality="blog",
|
||||||
|
url="https://example.com/post",
|
||||||
|
images=[str(img1_path), str(img2_path)],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
|
||||||
|
result = blog_post._chunk_contents()
|
||||||
|
|
||||||
|
result = [
|
||||||
|
[i if isinstance(i, str) else getattr(i, "filename") for i in c.data]
|
||||||
|
for c in result
|
||||||
|
]
|
||||||
|
assert result == [
|
||||||
|
[
|
||||||
|
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
|
||||||
|
img1_path.as_posix(),
|
||||||
|
img2_path.as_posix(),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
f"First picture is here: {img1_path.as_posix()}",
|
||||||
|
img1_path.as_posix(),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
f"Second picture is here: {img2_path.as_posix()}",
|
||||||
|
img2_path.as_posix(),
|
||||||
|
],
|
||||||
|
["test"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "metadata,expected_semantic_metadata,expected_temporal_metadata,observation_tags",
    [
        (
            {},
            {"embedding_type": "semantic"},
            {"embedding_type": "temporal"},
            [],
        ),
        (
            {"extra_key": "extra_value"},
            {"extra_key": "extra_value", "embedding_type": "semantic"},
            {"extra_key": "extra_value", "embedding_type": "temporal"},
            [],
        ),
        (
            {"tags": ["existing_tag"], "source": "test"},
            {"tags": {"existing_tag"}, "source": "test", "embedding_type": "semantic"},
            {"tags": {"existing_tag"}, "source": "test", "embedding_type": "temporal"},
            [],
        ),
    ],
)
def test_agent_observation_data_chunks(
    metadata, expected_semantic_metadata, expected_temporal_metadata, observation_tags
):
    """Test AgentObservation.data_chunks generates correct chunks with proper metadata"""
    observation = AgentObservation(
        sha256=b"test_obs",
        content="User prefers Python over JavaScript",
        subject="programming preferences",
        observation_type="preference",
        confidence=0.9,
        evidence={
            "quote": "I really like Python",
            "context": "discussion about languages",
        },
        agent_model="claude-3.5-sonnet",
        session_id=uuid.uuid4(),
        tags=observation_tags,
    )
    # Set inserted_at using object.__setattr__ to bypass SQLAlchemy restrictions
    object.__setattr__(observation, "inserted_at", datetime(2023, 1, 1, 12, 0, 0))

    result = observation.data_chunks(metadata)

    # Verify chunks
    assert len(result) == 2

    semantic_chunk = result[0]
    expected_semantic_text = "Subject: programming preferences | Type: preference | Observation: User prefers Python over JavaScript | Quote: I really like Python | Context: discussion about languages"
    assert semantic_chunk.data == [expected_semantic_text]
    assert semantic_chunk.metadata == expected_semantic_metadata
    assert semantic_chunk.collection_name == "semantic"

    temporal_chunk = result[1]
    expected_temporal_text = "Time: 12:00 on Sunday (afternoon) | Subject: programming preferences | Observation: User prefers Python over JavaScript | Confidence: 0.9"
    assert temporal_chunk.data == [expected_temporal_text]
    assert temporal_chunk.metadata == expected_temporal_metadata
    assert temporal_chunk.collection_name == "temporal"


def test_agent_observation_data_chunks_with_none_values():
    """Test AgentObservation.data_chunks handles None values correctly"""
    observation = AgentObservation(
        sha256=b"test_obs",
        content="Content",
        subject="subject",
        observation_type="belief",
        confidence=0.7,
        evidence=None,
        agent_model="gpt-4",
        session_id=None,
    )
    object.__setattr__(observation, "inserted_at", datetime(2023, 2, 15, 9, 30, 0))

    result = observation.data_chunks()

    assert len(result) == 2
    assert result[0].collection_name == "semantic"
    assert result[1].collection_name == "temporal"

    # Verify content with None evidence
    semantic_text = "Subject: subject | Type: belief | Observation: Content"
    assert result[0].data == [semantic_text]

    temporal_text = "Time: 09:30 on Wednesday (morning) | Subject: subject | Observation: Content | Confidence: 0.7"
    assert result[1].data == [temporal_text]


def test_agent_observation_data_chunks_merge_metadata_behavior():
    """Test that merge_metadata works correctly in data_chunks"""
    observation = AgentObservation(
        sha256=b"test",
        content="test",
        subject="test",
        observation_type="test",
        confidence=0.8,
        evidence={},
        agent_model="test",
        tags=["base_tag"],  # Set base tags so they appear in both chunks
    )
    object.__setattr__(observation, "inserted_at", datetime.now())

    # Test that metadata merging preserves original values and adds new ones
    input_metadata = {"existing": "value", "tags": ["tag1"]}
    result = observation.data_chunks(input_metadata)

    semantic_metadata = result[0].metadata
    temporal_metadata = result[1].metadata

    # Both should have the existing metadata plus embedding_type
    assert semantic_metadata["existing"] == "value"
    assert semantic_metadata["tags"] == {"tag1"}  # Merged tags
    assert semantic_metadata["embedding_type"] == "semantic"

    assert temporal_metadata["existing"] == "value"
    assert temporal_metadata["tags"] == {"tag1"}  # Merged tags
    assert temporal_metadata["embedding_type"] == "temporal"
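Note on the metadata assertions above: each chunk's metadata appears to be the caller-supplied metadata merged with an embedding_type marker, with tags normalised to a set. A minimal sketch consistent with those expectations follows; it is illustrative only and inferred from the assertions, not the project's actual merge_metadata helper, whose signature may differ.

# Illustrative only - inferred from the test assertions, not the real implementation.
def merge_metadata(*dicts: dict) -> dict:
    merged: dict = {}
    for d in dicts:
        for key, value in d.items():
            if key == "tags":
                # Tags accumulate as a set rather than being overwritten.
                merged["tags"] = set(merged.get("tags", set())) | set(value)
            else:
                merged[key] = value
    return merged

merge_metadata({"existing": "value", "tags": ["tag1"]}, {"embedding_type": "semantic"})
# -> {"existing": "value", "tags": {"tag1"}, "embedding_type": "semantic"}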
220
tests/memory/common/formatters/test_observation.py
Normal file
@ -0,0 +1,220 @@
import pytest
from datetime import datetime
from typing import Any

from memory.common.formatters.observation import (
    Evidence,
    generate_semantic_text,
    generate_temporal_text,
)


def test_generate_semantic_text_basic_functionality():
    evidence: Evidence = {"quote": "test quote", "context": "test context"}
    result = generate_semantic_text(
        subject="test_subject",
        observation_type="test_type",
        content="test_content",
        evidence=evidence,
    )
    assert (
        result
        == "Subject: test_subject | Type: test_type | Observation: test_content | Quote: test quote | Context: test context"
    )


@pytest.mark.parametrize(
    "evidence,expected_suffix",
    [
        ({"quote": "test quote"}, " | Quote: test quote"),
        ({"context": "test context"}, " | Context: test context"),
        ({}, ""),
    ],
)
def test_generate_semantic_text_partial_evidence(
    evidence: dict[str, str], expected_suffix: str
):
    result = generate_semantic_text(
        subject="subject",
        observation_type="type",
        content="content",
        evidence=evidence,  # type: ignore
    )
    expected = f"Subject: subject | Type: type | Observation: content{expected_suffix}"
    assert result == expected


def test_generate_semantic_text_none_evidence():
    result = generate_semantic_text(
        subject="subject",
        observation_type="type",
        content="content",
        evidence=None,  # type: ignore
    )
    assert result == "Subject: subject | Type: type | Observation: content"


@pytest.mark.parametrize(
    "invalid_evidence",
    [
        "string",
        123,
        ["list"],
        True,
    ],
)
def test_generate_semantic_text_invalid_evidence_types(invalid_evidence: Any):
    result = generate_semantic_text(
        subject="subject",
        observation_type="type",
        content="content",
        evidence=invalid_evidence,  # type: ignore
    )
    assert result == "Subject: subject | Type: type | Observation: content"


def test_generate_semantic_text_empty_strings():
    evidence = {"quote": "", "context": ""}
    result = generate_semantic_text(
        subject="",
        observation_type="",
        content="",
        evidence=evidence,  # type: ignore
    )
    assert result == "Subject: | Type: | Observation: | Quote: | Context: "


def test_generate_semantic_text_special_characters():
    evidence: Evidence = {
        "quote": "Quote with | pipe and | symbols",
        "context": "Context with special chars: @#$%",
    }
    result = generate_semantic_text(
        subject="Subject with | pipe",
        observation_type="Type with | pipe",
        content="Content with | pipe",
        evidence=evidence,
    )
    expected = "Subject: Subject with | pipe | Type: Type with | pipe | Observation: Content with | pipe | Quote: Quote with | pipe and | symbols | Context: Context with special chars: @#$%"
    assert result == expected


@pytest.mark.parametrize(
    "hour,expected_period",
    [
        (5, "morning"),
        (6, "morning"),
        (11, "morning"),
        (12, "afternoon"),
        (13, "afternoon"),
        (16, "afternoon"),
        (17, "evening"),
        (18, "evening"),
        (21, "evening"),
        (22, "late_night"),
        (23, "late_night"),
        (0, "late_night"),
        (1, "late_night"),
        (4, "late_night"),
    ],
)
def test_generate_temporal_text_time_periods(hour: int, expected_period: str):
    test_date = datetime(2024, 1, 15, hour, 30)  # Monday
    result = generate_temporal_text(
        subject="test_subject",
        content="test_content",
        confidence=0.8,
        created_at=test_date,
    )
    time_str = test_date.strftime("%H:%M")
    expected = f"Time: {time_str} on Monday ({expected_period}) | Subject: test_subject | Observation: test_content | Confidence: 0.8"
    assert result == expected


@pytest.mark.parametrize(
    "weekday,day_name",
    [
        (0, "Monday"),
        (1, "Tuesday"),
        (2, "Wednesday"),
        (3, "Thursday"),
        (4, "Friday"),
        (5, "Saturday"),
        (6, "Sunday"),
    ],
)
def test_generate_temporal_text_days_of_week(weekday: int, day_name: str):
    test_date = datetime(2024, 1, 15 + weekday, 10, 30)
    result = generate_temporal_text(
        subject="subject", content="content", confidence=0.5, created_at=test_date
    )
    assert f"on {day_name}" in result


@pytest.mark.parametrize("confidence", [0.0, 0.1, 0.5, 0.99, 1.0])
def test_generate_temporal_text_confidence_values(confidence: float):
    test_date = datetime(2024, 1, 15, 10, 30)
    result = generate_temporal_text(
        subject="subject",
        content="content",
        confidence=confidence,
        created_at=test_date,
    )
    assert f"Confidence: {confidence}" in result


@pytest.mark.parametrize(
    "test_date,expected_period",
    [
        (datetime(2024, 1, 15, 5, 0), "morning"),  # Start of morning
        (datetime(2024, 1, 15, 11, 59), "morning"),  # End of morning
        (datetime(2024, 1, 15, 12, 0), "afternoon"),  # Start of afternoon
        (datetime(2024, 1, 15, 16, 59), "afternoon"),  # End of afternoon
        (datetime(2024, 1, 15, 17, 0), "evening"),  # Start of evening
        (datetime(2024, 1, 15, 21, 59), "evening"),  # End of evening
        (datetime(2024, 1, 15, 22, 0), "late_night"),  # Start of late_night
        (datetime(2024, 1, 15, 4, 59), "late_night"),  # End of late_night
    ],
)
def test_generate_temporal_text_boundary_cases(
    test_date: datetime, expected_period: str
):
    result = generate_temporal_text(
        subject="subject", content="content", confidence=0.8, created_at=test_date
    )
    assert f"({expected_period})" in result


def test_generate_temporal_text_complete_format():
    test_date = datetime(2024, 3, 22, 14, 45)  # Friday afternoon
    result = generate_temporal_text(
        subject="Important observation",
        content="User showed strong preference for X",
        confidence=0.95,
        created_at=test_date,
    )
    expected = "Time: 14:45 on Friday (afternoon) | Subject: Important observation | Observation: User showed strong preference for X | Confidence: 0.95"
    assert result == expected


def test_generate_temporal_text_empty_strings():
    test_date = datetime(2024, 1, 15, 10, 30)
    result = generate_temporal_text(
        subject="", content="", confidence=0.0, created_at=test_date
    )
    assert (
        result
        == "Time: 10:30 on Monday (morning) | Subject: | Observation: | Confidence: 0.0"
    )


def test_generate_temporal_text_special_characters():
    test_date = datetime(2024, 1, 15, 15, 20)
    result = generate_temporal_text(
        subject="Subject with | pipe",
        content="Content with | pipe and @#$ symbols",
        confidence=0.75,
        created_at=test_date,
    )
    expected = "Time: 15:20 on Monday (afternoon) | Subject: Subject with | pipe | Observation: Content with | pipe and @#$ symbols | Confidence: 0.75"
    assert result == expected
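The expected strings in these tests imply a straightforward formatting scheme: pipe-separated "Label: value" fields, optional Quote and Context fields taken from a dict-typed evidence argument, and a coarse time-of-day bucket for the temporal text. A rough sketch consistent with the assertions is shown below; it is illustrative only, inferred from the tests, and the actual memory.common.formatters.observation module may be structured differently.

# Illustrative sketch inferred from the tests above; not the actual module source.
from datetime import datetime

def generate_semantic_text(subject, observation_type, content, evidence) -> str:
    # Core fields are always present; Quote/Context only when evidence is a dict with those keys.
    parts = [f"Subject: {subject}", f"Type: {observation_type}", f"Observation: {content}"]
    if isinstance(evidence, dict):
        if "quote" in evidence:
            parts.append(f"Quote: {evidence['quote']}")
        if "context" in evidence:
            parts.append(f"Context: {evidence['context']}")
    return " | ".join(parts)

def generate_temporal_text(subject, content, confidence, created_at: datetime) -> str:
    # Bucket the hour into the periods the boundary tests expect.
    hour = created_at.hour
    if 5 <= hour < 12:
        period = "morning"
    elif 12 <= hour < 17:
        period = "afternoon"
    elif 17 <= hour < 22:
        period = "evening"
    else:
        period = "late_night"
    return (
        f"Time: {created_at:%H:%M} on {created_at:%A} ({period}) | "
        f"Subject: {subject} | Observation: {content} | Confidence: {confidence}"
    )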
@ -18,6 +18,7 @@ from memory.workers.tasks.content_processing import (
    process_content_item,
    push_to_qdrant,
    safe_task_execution,
    by_collection,
)
@ -71,6 +72,7 @@ def mock_chunk():
    chunk.id = "00000000-0000-0000-0000-000000000001"
    chunk.vector = [0.1] * 1024
    chunk.item_metadata = {"source_id": 1, "tags": ["test"]}
    chunk.collection_name = "mail"
    return chunk
@ -242,17 +244,19 @@ def test_push_to_qdrant_success(qdrant):
    mock_chunk1.id = "00000000-0000-0000-0000-000000000001"
    mock_chunk1.vector = [0.1] * 1024
    mock_chunk1.item_metadata = {"source_id": 1, "tags": ["test"]}
    mock_chunk1.collection_name = "mail"

    mock_chunk2 = MagicMock()
    mock_chunk2.id = "00000000-0000-0000-0000-000000000002"
    mock_chunk2.vector = [0.2] * 1024
    mock_chunk2.item_metadata = {"source_id": 2, "tags": ["test"]}
    mock_chunk2.collection_name = "mail"

    # Assign chunks directly (bypassing SQLAlchemy relationship)
    item1.chunks = [mock_chunk1]
    item2.chunks = [mock_chunk2]

-    push_to_qdrant([item1, item2], "mail")
+    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == "STORED"
    assert str(item2.embed_status) == "STORED"
@ -294,6 +298,7 @@ def test_push_to_qdrant_no_processing(
            mock_chunk.id = f"00000000-0000-0000-0000-00000000000{suffix}"
            mock_chunk.vector = [0.1] * 1024
            mock_chunk.item_metadata = {"source_id": int(suffix), "tags": ["test"]}
            mock_chunk.collection_name = "mail"
            item.chunks = [mock_chunk]
        else:
            item.chunks = []
@ -302,7 +307,7 @@ def test_push_to_qdrant_no_processing(
    item1 = create_item("1", item1_status, item1_has_chunks)
    item2 = create_item("2", item2_status, item2_has_chunks)

-    push_to_qdrant([item1, item2], "mail")
+    push_to_qdrant([item1, item2])

    assert str(item1.embed_status) == expected_item1_status
    assert str(item2.embed_status) == expected_item2_status
@ -317,7 +322,7 @@ def test_push_to_qdrant_exception(sample_mail_message, mock_chunk):
        side_effect=Exception("Qdrant error"),
    ):
        with pytest.raises(Exception, match="Qdrant error"):
-            push_to_qdrant([sample_mail_message], "mail")
+            push_to_qdrant([sample_mail_message])

    assert str(sample_mail_message.embed_status) == "FAILED"
@ -418,6 +423,196 @@ def test_create_task_result_no_title():
    assert result["chunks_count"] == 0


def test_by_collection_empty_chunks():
    result = by_collection([])
    assert result == {}


def test_by_collection_single_chunk():
    chunk = Chunk(
        id="00000000-0000-0000-0000-000000000001",
        content="test content",
        embedding_model="test-model",
        vector=[0.1, 0.2, 0.3],
        item_metadata={"source_id": 1, "tags": ["test"]},
        collection_name="test_collection",
    )

    result = by_collection([chunk])

    assert len(result) == 1
    assert "test_collection" in result
    assert result["test_collection"]["ids"] == ["00000000-0000-0000-0000-000000000001"]
    assert result["test_collection"]["vectors"] == [[0.1, 0.2, 0.3]]
    assert result["test_collection"]["payloads"] == [{"source_id": 1, "tags": ["test"]}]


def test_by_collection_multiple_chunks_same_collection():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content 1",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"source_id": 1},
            collection_name="collection_a",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.3, 0.4],
            item_metadata={"source_id": 2},
            collection_name="collection_a",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 1
    assert "collection_a" in result
    assert result["collection_a"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
    ]
    assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.3, 0.4]]
    assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 2}]


def test_by_collection_multiple_chunks_different_collections():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content 1",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"source_id": 1},
            collection_name="collection_a",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.3, 0.4],
            item_metadata={"source_id": 2},
            collection_name="collection_b",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000003",
            content="test content 3",
            embedding_model="test-model",
            vector=[0.5, 0.6],
            item_metadata={"source_id": 3},
            collection_name="collection_a",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 2
    assert "collection_a" in result
    assert "collection_b" in result

    # Check collection_a
    assert result["collection_a"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000003",
    ]
    assert result["collection_a"]["vectors"] == [[0.1, 0.2], [0.5, 0.6]]
    assert result["collection_a"]["payloads"] == [{"source_id": 1}, {"source_id": 3}]

    # Check collection_b
    assert result["collection_b"]["ids"] == ["00000000-0000-0000-0000-000000000002"]
    assert result["collection_b"]["vectors"] == [[0.3, 0.4]]
    assert result["collection_b"]["payloads"] == [{"source_id": 2}]


@pytest.mark.parametrize(
    "collection_names,expected_collections",
    [
        (["col1", "col1", "col1"], 1),
        (["col1", "col2", "col3"], 3),
        (["col1", "col2", "col1", "col2"], 2),
        (["single"], 1),
    ],
)
def test_by_collection_various_groupings(collection_names, expected_collections):
    chunks = [
        Chunk(
            id=f"00000000-0000-0000-0000-00000000000{i}",
            content=f"test content {i}",
            embedding_model="test-model",
            vector=[float(i)],
            item_metadata={"index": i},
            collection_name=collection_name,
        )
        for i, collection_name in enumerate(collection_names, 1)
    ]

    result = by_collection(chunks)

    assert len(result) == expected_collections
    # Verify all chunks are accounted for
    total_chunks = sum(len(coll["ids"]) for coll in result.values())
    assert total_chunks == len(chunks)


def test_by_collection_with_none_values():
    chunks = [
        Chunk(
            id="00000000-0000-0000-0000-000000000001",
            content="test content",
            embedding_model="test-model",
            vector=None,  # None vector
            item_metadata=None,  # None metadata
            collection_name="test_collection",
        ),
        Chunk(
            id="00000000-0000-0000-0000-000000000002",
            content="test content 2",
            embedding_model="test-model",
            vector=[0.1, 0.2],
            item_metadata={"key": "value"},
            collection_name="test_collection",
        ),
    ]

    result = by_collection(chunks)

    assert len(result) == 1
    assert "test_collection" in result
    assert result["test_collection"]["ids"] == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
    ]
    assert result["test_collection"]["vectors"] == [None, [0.1, 0.2]]
    assert result["test_collection"]["payloads"] == [None, {"key": "value"}]


def test_by_collection_preserves_order():
    chunks = []
    for i in range(5):
        chunks.append(
            Chunk(
                id=f"00000000-0000-0000-0000-00000000000{i}",
                content=f"test content {i}",
                embedding_model="test-model",
                vector=[float(i)],
                item_metadata={"order": i},
                collection_name="ordered_collection",
            )
        )

    result = by_collection(chunks)

    assert len(result) == 1
    assert result["ordered_collection"]["ids"] == [
        f"00000000-0000-0000-0000-00000000000{i}" for i in range(5)
    ]
    assert result["ordered_collection"]["vectors"] == [[float(i)] for i in range(5)]
    assert result["ordered_collection"]["payloads"] == [{"order": i} for i in range(5)]


@pytest.mark.parametrize(
    "embedding_return,qdrant_error,expected_status,expected_embed_status",
    [
@ -459,6 +654,7 @@ def test_process_content_item(
            embedding_model="test-model",
            vector=[0.1] * 1024,
            item_metadata={"source_id": 1, "tags": ["test"]},
            collection_name="mail",
        )
        mock_chunks = [real_chunk]
    else:  # empty
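The by_collection tests added above describe a simple grouping helper: chunk ids, vectors and payloads are bucketed by each chunk's collection_name, in input order, with None values passed through unchanged. A minimal sketch with that behaviour is given below; it is illustrative only, and the real function in memory.workers.tasks.content_processing may differ.

# Illustrative sketch matching the test expectations above; not the project's actual code.
from collections import defaultdict

def by_collection(chunks) -> dict[str, dict[str, list]]:
    # Group ids, vectors and payloads per collection, preserving input order.
    grouped: dict[str, dict[str, list]] = defaultdict(
        lambda: {"ids": [], "vectors": [], "payloads": []}
    )
    for chunk in chunks:
        bucket = grouped[chunk.collection_name]
        bucket["ids"].append(chunk.id)
        bucket["vectors"].append(chunk.vector)
        bucket["payloads"].append(chunk.item_metadata)
    return dict(grouped)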
@ -2,6 +2,7 @@
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch, call
from typing import cast

import pytest
from PIL import Image
@ -350,6 +351,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=chunk_id,
            source=item,
            content=f"Test chunk content {i}",
            collection_name=item.modality,
            embedding_model="test-model",
        )
        for i, chunk_id in enumerate(chunk_ids)
@ -358,7 +360,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
    db_session.commit()

    # Add vectors to Qdrant
-    modality = "mail" if item_type == "MailMessage" else "blog"
+    modality = cast(str, item.modality)
    qd.ensure_collection_exists(qdrant, modality, 1024)
    qd.upsert_vectors(qdrant, modality, chunk_ids, [[1] * 1024] * len(chunk_ids))

@ -375,6 +377,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=str(uuid.uuid4()),
            content="New chunk content 1",
            embedding_model="test-model",
            collection_name=modality,
            vector=[0.1] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -382,6 +385,7 @@ def test_reingest_item_success(db_session, qdrant, item_type):
            id=str(uuid.uuid4()),
            content="New chunk content 2",
            embedding_model="test-model",
            collection_name=modality,
            vector=[0.2] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -449,6 +453,7 @@ def test_reingest_item_no_chunks(db_session, qdrant):
            id=str(uuid.uuid4()),
            content="New chunk content",
            embedding_model="test-model",
            collection_name=item.modality,
            vector=[0.1] * 1024,
            item_metadata={"source_id": item.id, "tags": ["test"]},
        ),
@ -538,6 +543,7 @@ def test_reingest_empty_source_items_success(db_session, item_type):
        source=item_with_chunks,
        content="Test chunk content",
        embedding_model="test-model",
        collection_name=item_with_chunks.modality,
    )
    db_session.add(chunk)
    db_session.commit()
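Taken together, these diffs move the target collection off the push_to_qdrant call and onto each Chunk (collection_name), so a single call can fan out across multiple collections. A hedged sketch of the implied flow follows; the upsert function here is a hypothetical stand-in, not the project's real Qdrant wrapper, and by_collection is the grouping helper sketched earlier.

# Hedged sketch of the implied flow; `upsert` is a placeholder, not the real qd API.
from typing import Iterable

def upsert(collection: str, ids: list, vectors: list, payloads: list) -> None:
    # Placeholder for the actual Qdrant upsert call used by the project.
    print(f"upsert {len(ids)} vectors into {collection}")

def push_to_qdrant(items: Iterable) -> None:
    # Collect all chunks, then write one batch per collection.
    chunks = [chunk for item in items for chunk in item.chunks]
    for collection, batch in by_collection(chunks).items():
        upsert(collection, batch["ids"], batch["vectors"], batch["payloads"])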