use proper chunk objects

This commit is contained in:
Daniel O'Connell 2025-05-03 16:13:38 +02:00
parent 453aed7c19
commit 44de394eb1
24 changed files with 2013 additions and 1697 deletions

View File

@ -1,8 +1,8 @@
"""Initial structure
"""Initial structure for the database.
Revision ID: a466a07360d5
Revises:
Create Date: 2025-04-27 17:15:37.487616
Revision ID: 4684845ca51e
Revises: a466a07360d5
Create Date: 2025-05-03 14:00:56.113840
"""
@ -13,7 +13,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "a466a07360d5"
revision: str = "4684845ca51e"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -21,12 +21,7 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
# Create enum type for github_item with IF NOT EXISTS
op.execute(
"DO $$ BEGIN CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); EXCEPTION WHEN duplicate_object THEN NULL; END $$;"
)
op.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
op.create_table(
"email_accounts",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -56,6 +51,22 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email_address"),
)
op.create_index(
"email_accounts_active_idx",
"email_accounts",
["active", "last_sync_at"],
unique=False,
)
op.create_index(
"email_accounts_address_idx", "email_accounts", ["email_address"], unique=True
)
op.create_index(
"email_accounts_tags_idx",
"email_accounts",
["tags"],
unique=False,
postgresql_using="gin",
)
op.create_table(
"rss_feeds",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -80,6 +91,16 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_index(
"rss_feeds_active_idx", "rss_feeds", ["active", "last_checked_at"], unique=False
)
op.create_index(
"rss_feeds_tags_idx",
"rss_feeds",
["tags"],
unique=False,
postgresql_using="gin",
)
op.create_table(
"source_item",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -92,134 +113,103 @@ def upgrade() -> None:
nullable=True,
),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("lang", sa.Text(), nullable=True),
sa.Column("model_hash", sa.Text(), nullable=True),
sa.Column(
"vector_ids", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
sa.Column("byte_length", sa.Integer(), nullable=True),
sa.Column("size", sa.Integer(), nullable=True),
sa.Column("mime_type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
sa.Column("type", sa.String(length=50), nullable=True),
sa.CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("sha256"),
)
op.create_index("source_filename_idx", "source_item", ["filename"], unique=False)
op.create_index("source_modality_idx", "source_item", ["modality"], unique=False)
op.create_index("source_status_idx", "source_item", ["embed_status"], unique=False)
op.create_index(
"source_tags_idx", "source_item", ["tags"], unique=False, postgresql_using="gin"
)
op.create_table(
"blog_post",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("url", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_table(
"book_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("chapter", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"chat_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("platform", sa.Text(), nullable=True),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"git_commit",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("repo_path", sa.Text(), nullable=True),
sa.Column("commit_sha", sa.Text(), nullable=True),
sa.Column("author_name", sa.Text(), nullable=True),
sa.Column("author_email", sa.Text(), nullable=True),
sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
sa.Column("msg_raw", sa.Text(), nullable=True),
sa.Column("diff_summary", sa.Text(), nullable=True),
sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("commit_sha"),
op.create_index(
"chat_channel_idx", "chat_message", ["platform", "channel_id"], unique=False
)
op.create_table(
"mail_message",
sa.Column("id", sa.BigInteger(), nullable=False),
"chunk",
sa.Column(
"id",
sa.UUID(),
server_default=sa.text("uuid_generate_v4()"),
nullable=False,
),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("subject", sa.Text(), nullable=True),
sa.Column("sender", sa.Text(), nullable=True),
sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.Column("folder", sa.Text(), nullable=True),
sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("message_id"),
)
op.create_table(
"email_attachment",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
sa.Column("filename", sa.Text(), nullable=False),
sa.Column("content_type", sa.Text(), nullable=True),
sa.Column("size", sa.Integer(), nullable=True),
sa.Column("content", postgresql.BYTEA(), nullable=True),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("embedding_model", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["source_id"], ["source_item.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"misc_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("path", sa.Text(), nullable=True),
sa.Column("mime_type", sa.Text(), nullable=True),
sa.CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("chunk_source_idx", "chunk", ["source_id"], unique=False)
op.create_table(
"photo",
"git_commit",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("exif_lat", sa.Numeric(9, 6), nullable=True),
sa.Column("exif_lon", sa.Numeric(9, 6), nullable=True),
sa.Column("camera", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.Column("repo_path", sa.Text(), nullable=True),
sa.Column("commit_sha", sa.Text(), nullable=True),
sa.Column("author_name", sa.Text(), nullable=True),
sa.Column("author_email", sa.Text(), nullable=True),
sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
sa.Column("diff_summary", sa.Text(), nullable=True),
sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("commit_sha"),
)
op.create_index("git_date_idx", "git_commit", ["author_date"], unique=False)
op.create_index(
"git_files_idx",
"git_commit",
["files_changed"],
unique=False,
postgresql_using="gin",
)
# Add github_item table
op.create_table(
"github_item",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("kind", sa.Text(), nullable=False),
sa.Column("repo_path", sa.Text(), nullable=False),
sa.Column("number", sa.Integer(), nullable=True),
@ -227,7 +217,6 @@ def upgrade() -> None:
sa.Column("commit_sha", sa.Text(), nullable=True),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.Column("labels", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
@ -235,147 +224,125 @@ def upgrade() -> None:
sa.Column("merged_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("diff_summary", sa.Text(), nullable=True),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Add constraint to github_item.kind
op.create_check_constraint(
"github_item_kind_check",
op.create_index(
"gh_issue_lookup_idx",
"github_item",
"kind IN ('issue', 'pr', 'comment', 'project_card')",
)
# Add missing constraint to source_item
op.create_check_constraint(
"source_item_embed_status_check",
"source_item",
"embed_status IN ('RAW','QUEUED','STORED','FAILED')",
)
# Create trigger function for vector_ids validation
op.execute("""
CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
RETURNS TRIGGER LANGUAGE plpgsql AS $$
BEGIN
IF NEW.embed_status = 'STORED'
AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
RAISE EXCEPTION
USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
END IF;
RETURN NEW;
END;
$$;
""")
# Create trigger
op.execute("""
CREATE TRIGGER check_vector_ids
BEFORE UPDATE ON source_item
FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();
""")
# Create indexes for source_item
op.create_index("source_modality_idx", "source_item", ["modality"])
op.create_index("source_status_idx", "source_item", ["embed_status"])
op.create_index("source_tags_idx", "source_item", ["tags"], postgresql_using="gin")
# Create indexes for mail_message
op.create_index("mail_sent_idx", "mail_message", ["sent_at"])
op.create_index(
"mail_recipients_idx", "mail_message", ["recipients"], postgresql_using="gin"
)
op.create_index("email_attachment_filename_idx", "email_attachment", ["filename"], unique=False)
op.create_index("email_attachment_message_idx", "email_attachment", ["mail_message_id"], unique=False)
op.create_index("mail_tsv_idx", "mail_message", ["tsv"], postgresql_using="gin")
# Create index for chat_message
op.create_index("chat_channel_idx", "chat_message", ["platform", "channel_id"])
# Create indexes for git_commit
op.create_index(
"git_files_idx", "git_commit", ["files_changed"], postgresql_using="gin"
)
op.create_index("git_date_idx", "git_commit", ["author_date"])
# Create index for photo
op.create_index("photo_taken_idx", "photo", ["exif_taken_at"])
# Create indexes for rss_feeds
op.create_index("rss_feeds_active_idx", "rss_feeds", ["active", "last_checked_at"])
op.create_index("rss_feeds_tags_idx", "rss_feeds", ["tags"], postgresql_using="gin")
# Create indexes for email_accounts
op.create_index(
"email_accounts_address_idx", "email_accounts", ["email_address"], unique=True
["repo_path", "kind", "number"],
unique=False,
)
op.create_index(
"email_accounts_active_idx", "email_accounts", ["active", "last_sync_at"]
"gh_labels_idx", "github_item", ["labels"], unique=False, postgresql_using="gin"
)
op.create_index(
"email_accounts_tags_idx", "email_accounts", ["tags"], postgresql_using="gin"
"gh_repo_kind_idx", "github_item", ["repo_path", "kind"], unique=False
)
op.create_table(
"mail_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("subject", sa.Text(), nullable=True),
sa.Column("sender", sa.Text(), nullable=True),
sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("folder", sa.Text(), nullable=True),
sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("message_id"),
)
# Create indexes for github_item
op.create_index("gh_repo_kind_idx", "github_item", ["repo_path", "kind"])
op.create_index(
"gh_issue_lookup_idx", "github_item", ["repo_path", "kind", "number"]
"mail_recipients_idx",
"mail_message",
["recipients"],
unique=False,
postgresql_using="gin",
)
op.create_index("mail_sent_idx", "mail_message", ["sent_at"], unique=False)
op.create_index(
"mail_tsv_idx", "mail_message", ["tsv"], unique=False, postgresql_using="gin"
)
op.create_table(
"misc_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("path", sa.Text(), nullable=True),
sa.Column("mime_type", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"photo",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("exif_lat", sa.Numeric(precision=9, scale=6), nullable=True),
sa.Column("exif_lon", sa.Numeric(precision=9, scale=6), nullable=True),
sa.Column("camera", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("photo_taken_idx", "photo", ["exif_taken_at"], unique=False)
op.create_table(
"email_attachment",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"email_attachment_message_idx",
"email_attachment",
["mail_message_id"],
unique=False,
)
op.create_index("gh_labels_idx", "github_item", ["labels"], postgresql_using="gin")
# Create add_tags helper function
op.execute("""
CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
RETURNS VOID LANGUAGE SQL AS $$
UPDATE source_item
SET tags =
(SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
WHERE id = p_source;
$$;
""")
def downgrade() -> None:
# Drop indexes
op.drop_index("gh_labels_idx", table_name="github_item")
op.drop_index("gh_issue_lookup_idx", table_name="github_item")
op.drop_index("gh_repo_kind_idx", table_name="github_item")
op.drop_index("email_accounts_tags_idx", table_name="email_accounts")
op.drop_index("email_accounts_active_idx", table_name="email_accounts")
op.drop_index("email_accounts_address_idx", table_name="email_accounts")
op.drop_index("rss_feeds_tags_idx", table_name="rss_feeds")
op.drop_index("rss_feeds_active_idx", table_name="rss_feeds")
op.drop_index("photo_taken_idx", table_name="photo")
op.drop_index("git_date_idx", table_name="git_commit")
op.drop_index("git_files_idx", table_name="git_commit")
op.drop_index("chat_channel_idx", table_name="chat_message")
op.drop_index("mail_tsv_idx", table_name="mail_message")
op.drop_index("mail_recipients_idx", table_name="mail_message")
op.drop_index("mail_sent_idx", table_name="mail_message")
op.drop_index("email_attachment_message_idx", table_name="email_attachment")
op.drop_index("email_attachment_filename_idx", table_name="email_attachment")
op.drop_index("source_tags_idx", table_name="source_item")
op.drop_index("source_status_idx", table_name="source_item")
op.drop_index("source_modality_idx", table_name="source_item")
# Drop tables
op.drop_table("email_attachment")
op.drop_index("photo_taken_idx", table_name="photo")
op.drop_table("photo")
op.drop_table("misc_doc")
op.drop_index("mail_tsv_idx", table_name="mail_message", postgresql_using="gin")
op.drop_index("mail_sent_idx", table_name="mail_message")
op.drop_index(
"mail_recipients_idx", table_name="mail_message", postgresql_using="gin"
)
op.drop_table("mail_message")
op.drop_index("gh_repo_kind_idx", table_name="github_item")
op.drop_index("gh_labels_idx", table_name="github_item", postgresql_using="gin")
op.drop_index("gh_issue_lookup_idx", table_name="github_item")
op.drop_table("github_item")
op.drop_index("git_files_idx", table_name="git_commit", postgresql_using="gin")
op.drop_index("git_date_idx", table_name="git_commit")
op.drop_table("git_commit")
op.drop_index("chunk_source_idx", table_name="chunk")
op.drop_table("chunk")
op.drop_index("chat_channel_idx", table_name="chat_message")
op.drop_table("chat_message")
op.drop_table("book_doc")
op.drop_table("blog_post")
op.drop_table("github_item")
op.drop_table("email_attachment")
op.drop_table("mail_message")
op.drop_index("source_tags_idx", table_name="source_item", postgresql_using="gin")
op.drop_index("source_status_idx", table_name="source_item")
op.drop_index("source_modality_idx", table_name="source_item")
op.drop_table("source_item")
op.drop_index("rss_feeds_tags_idx", table_name="rss_feeds", postgresql_using="gin")
op.drop_index("rss_feeds_active_idx", table_name="rss_feeds")
op.drop_table("rss_feeds")
op.drop_index(
"email_accounts_tags_idx", table_name="email_accounts", postgresql_using="gin"
)
op.drop_index("email_accounts_address_idx", table_name="email_accounts")
op.drop_index("email_accounts_active_idx", table_name="email_accounts")
op.drop_table("email_accounts")
# Drop triggers and functions
op.execute("DROP TRIGGER IF EXISTS check_vector_ids ON source_item")
op.execute("DROP FUNCTION IF EXISTS trg_vector_ids_not_empty()")
op.execute("DROP FUNCTION IF EXISTS add_tags(BIGINT, TEXT[])")
# Drop enum type
op.execute("DROP TYPE IF EXISTS gh_item_kind")

View File

@ -0,0 +1,130 @@
import logging
from typing import Iterable, Any
import re
logger = logging.getLogger(__name__)
# Chunking configuration
MAX_TOKENS = 32000 # VoyageAI max context window
OVERLAP_TOKENS = 200 # Default overlap between chunks
CHARS_PER_TOKEN = 4
Vector = list[float]
Embedding = tuple[str, Vector, dict[str, Any]]
# Regex for sentence splitting
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
def approx_token_count(s: str) -> int:
return len(s) // CHARS_PER_TOKEN
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
words = text.split()
if not words:
return
current = ""
for word in words:
new_chunk = f"{current} {word}".strip()
if current and approx_token_count(new_chunk) > max_tokens:
yield current
current = word
else:
current = new_chunk
if current: # Only yield non-empty final chunk
yield current
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
"""
Yield text spans in priority order: paragraphs, sentences, words.
Each span is kept under max_tokens where possible (a single over-long word may still exceed it).
Args:
text: The text to split
max_tokens: Maximum tokens per chunk
Yields:
Spans of text that fit within token limits
"""
# Return early for empty text
if not text.strip():
return
for paragraph in text.split("\n\n"):
if not paragraph.strip():
continue
if approx_token_count(paragraph) <= max_tokens:
yield paragraph
continue
for sentence in _SENT_SPLIT_RE.split(paragraph):
if not sentence.strip():
continue
if approx_token_count(sentence) <= max_tokens:
yield sentence
continue
for chunk in yield_word_chunks(sentence, max_tokens):
yield chunk
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
"""
Split text into chunks respecting semantic boundaries while staying within token limits.
Args:
text: The text to chunk
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
overlap: Number of tokens to overlap between chunks (default: 200)
Yields:
Text chunks that stay within the token limit
"""
text = text.strip()
if not text:
return
if approx_token_count(text) <= max_tokens:
yield text
return
overlap_chars = overlap * CHARS_PER_TOKEN
current = ""
for span in yield_spans(text, max_tokens):
current = f"{current} {span}".strip()
if approx_token_count(current) < max_tokens:
continue
if overlap <= 0:
yield current
current = ""
continue
overlap_text = current[-overlap_chars:]
clean_break = max(
overlap_text.rfind(". "),
overlap_text.rfind("! "),
overlap_text.rfind("? ")
)
if clean_break < 0:
yield current
current = ""
continue
break_offset = -overlap_chars + clean_break + 1
chunk = current[break_offset:].strip()
yield current
current = chunk
if current:
yield current.strip()
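A minimal usage sketch of the chunker above, assuming the module lives at memory.common.chunker (the import path used later in this commit); the input text and token budget are made up for illustration.

# Illustrative only: a small max_tokens makes the paragraph/sentence fallback visible.
from memory.common.chunker import approx_token_count, chunk_text

text = (
    "First paragraph. Short and sweet.\n\n"
    "Second paragraph with several sentences. Each one can become its own span. "
    "Overlap keeps a little context between adjacent chunks."
)

for i, chunk in enumerate(chunk_text(text, max_tokens=16, overlap=4)):
    print(i, approx_token_count(chunk), repr(chunk))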

View File

@ -1,257 +1,351 @@
"""
Database models for the knowledge base system.
"""
import pathlib
import re
from email.message import EmailMessage
from pathlib import Path
from typing import Any
from PIL import Image
from sqlalchemy import (
Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean,
ARRAY, func, Numeric, CheckConstraint, Index
ARRAY,
UUID,
BigInteger,
Boolean,
CheckConstraint,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
func,
)
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from memory.common import settings
from memory.common.parsers.email import parse_email_message
Base = declarative_base()
def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
__tablename__ = "chunk"
# The ID is also used as the vector ID in the vector database
id = Column(
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
)
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_path = Column(Text) # Path to content if stored as a file
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
vector = list[float] | None # the vector generated by the embedding model
item_metadata = dict[str, Any] | None
# One of file_path or content must be populated
__table_args__ = (
CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def data(self) -> list[bytes | str | Image.Image]:
if not self.file_path:
return [self.content]
path = pathlib.Path(self.file_path)
if self.file_path.endswith("*"):
files = list(path.parent.glob(path.name))
else:
files = [path]
items = []
for file_path in files:
if file_path.suffix == ".png":
items.append(Image.open(file_path))
elif file_path.suffix == ".bin":
items.append(file_path.read_bytes())
else:
items.append(file_path.read_text())
return items
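A brief, hypothetical illustration of the convention the data property relies on: file_path may be a plain path, a glob ending in "*" for multi-part chunks, or empty when content is stored inline. The paths below are invented.

from memory.common.db.models import Chunk

inline = Chunk(content="just text")
print(inline.data)                     # ["just text"]

multi = Chunk(file_path="/tmp/chunks/abc_*")
# multi.data would glob /tmp/chunks/abc_0.png, /tmp/chunks/abc_1.txt, ...
# and load each file as an Image, bytes, or text depending on its suffix.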
class SourceItem(Base):
__tablename__ = 'source_item'
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
__tablename__ = "source_item"
id = Column(BigInteger, primary_key=True)
modality = Column(Text, nullable=False)
sha256 = Column(BYTEA, nullable=False, unique=True)
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
lang = Column(Text)
model_hash = Column(Text)
vector_ids = Column(ARRAY(Text), nullable=False, server_default='{}')
embed_status = Column(Text, nullable=False, server_default='RAW')
byte_length = Column(Integer)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
size = Column(Integer)
mime_type = Column(Text)
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False)
# Content is stored in the database if it's small enough and text
content = Column(Text)
# Otherwise the content is stored on disk
filename = Column(Text, nullable=True)
# Chunks relationship
embed_status = Column(Text, nullable=False, server_default="RAW")
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
# Discriminator column for SQLAlchemy inheritance
type = Column(String(50))
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index('source_modality_idx', 'modality'),
Index('source_status_idx', 'embed_status'),
Index('source_tags_idx', 'tags', postgresql_using='gin'),
Index("source_modality_idx", "modality"),
Index("source_status_idx", "embed_status"),
Index("source_tags_idx", "tags", postgresql_using="gin"),
Index("source_filename_idx", "filename"),
)
@property
def vector_ids(self):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
class MailMessage(Base):
__tablename__ = 'mail_message'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
class MailMessage(SourceItem):
__tablename__ = "mail_message"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
message_id = Column(Text, unique=True)
subject = Column(Text)
sender = Column(Text)
recipients = Column(ARRAY(Text))
sent_at = Column(DateTime(timezone=True))
body_raw = Column(Text)
folder = Column(Text)
tsv = Column(TSVECTOR)
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
source = relationship("SourceItem", back_populates="mail_message")
def __init__(self, **kwargs):
if not kwargs.get("modality"):
kwargs["modality"] = "email"
super().__init__(**kwargs)
attachments = relationship(
"EmailAttachment",
back_populates="mail_message",
foreign_keys="EmailAttachment.mail_message_id",
cascade="all, delete-orphan",
)
__mapper_args__ = {
"polymorphic_identity": "mail_message",
}
@property
def attachments_path(self) -> Path:
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
clean_sender = clean_filename(self.sender)
clean_folder = clean_filename(self.folder or "INBOX")
return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
def safe_filename(self, filename: str) -> Path:
suffix = Path(filename).suffix
name = clean_filename(filename.removesuffix(suffix)) + suffix
path = self.attachments_path / name
path.parent.mkdir(parents=True, exist_ok=True)
return path
def as_payload(self) -> dict:
return {
"source_id": self.source_id,
"source_id": self.id,
"message_id": self.message_id,
"subject": self.subject,
"sender": self.sender,
"recipients": self.recipients,
"folder": self.folder,
"tags": self.source.tags,
"tags": self.tags,
"date": self.sent_at and self.sent_at.isoformat() or None,
}
@property
def parsed_content(self) -> EmailMessage:
return parse_email_message(self.content, self.message_id)
@property
def body(self) -> str:
return self.parsed_content["body"]
# Add indexes
__table_args__ = (
Index('mail_sent_idx', 'sent_at'),
Index('mail_recipients_idx', 'recipients', postgresql_using='gin'),
Index('mail_tsv_idx', 'tsv', postgresql_using='gin'),
Index("mail_sent_idx", "sent_at"),
Index("mail_recipients_idx", "recipients", postgresql_using="gin"),
Index("mail_tsv_idx", "tsv", postgresql_using="gin"),
)
class EmailAttachment(Base):
__tablename__ = 'email_attachment'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False)
filename = Column(Text, nullable=False)
content_type = Column(Text)
size = Column(Integer)
content = Column(BYTEA) # For small files stored inline
file_path = Column(Text) # For larger files stored on disk
class EmailAttachment(SourceItem):
__tablename__ = "email_attachment"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
mail_message_id = Column(
BigInteger, ForeignKey("mail_message.id", ondelete="CASCADE"), nullable=False
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
mail_message = relationship("MailMessage", back_populates="attachments")
source = relationship("SourceItem", back_populates="attachments")
mail_message = relationship(
"MailMessage", back_populates="attachments", foreign_keys=[mail_message_id]
)
__mapper_args__ = {
"polymorphic_identity": "email_attachment",
}
def as_payload(self) -> dict:
return {
"filename": self.filename,
"content_type": self.content_type,
"content_type": self.mime_type,
"size": self.size,
"created_at": self.created_at and self.created_at.isoformat() or None,
"mail_message_id": self.mail_message_id,
"source_id": self.mail_message.source_id,
"tags": self.mail_message.source.tags,
"source_id": self.id,
"tags": self.tags,
}
# Add indexes
__table_args__ = (
Index('email_attachment_message_idx', 'mail_message_id'),
Index('email_attachment_filename_idx', 'filename'),
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
class ChatMessage(SourceItem):
__tablename__ = "chat_message"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
class ChatMessage(Base):
__tablename__ = 'chat_message'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
platform = Column(Text)
channel_id = Column(Text)
author = Column(Text)
sent_at = Column(DateTime(timezone=True))
body_raw = Column(Text)
__mapper_args__ = {
"polymorphic_identity": "chat_message",
}
# Add index
__table_args__ = (
Index('chat_channel_idx', 'platform', 'channel_id'),
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
class GitCommit(SourceItem):
__tablename__ = "git_commit"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
class GitCommit(Base):
__tablename__ = 'git_commit'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
repo_path = Column(Text)
commit_sha = Column(Text, unique=True)
author_name = Column(Text)
author_email = Column(Text)
author_date = Column(DateTime(timezone=True))
msg_raw = Column(Text)
diff_summary = Column(Text)
files_changed = Column(ARRAY(Text))
__mapper_args__ = {
"polymorphic_identity": "git_commit",
}
# Add indexes
__table_args__ = (
Index('git_files_idx', 'files_changed', postgresql_using='gin'),
Index('git_date_idx', 'author_date'),
Index("git_files_idx", "files_changed", postgresql_using="gin"),
Index("git_date_idx", "author_date"),
)
class Photo(Base):
__tablename__ = 'photo'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
file_path = Column(Text)
class Photo(SourceItem):
__tablename__ = "photo"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
exif_taken_at = Column(DateTime(timezone=True))
exif_lat = Column(Numeric(9, 6))
exif_lon = Column(Numeric(9, 6))
camera = Column(Text)
__mapper_args__ = {
"polymorphic_identity": "photo",
}
# Add index
__table_args__ = (
Index('photo_taken_idx', 'exif_taken_at'),
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
class BookDoc(SourceItem):
__tablename__ = "book_doc"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
class BookDoc(Base):
__tablename__ = 'book_doc'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
title = Column(Text)
author = Column(Text)
chapter = Column(Text)
published = Column(DateTime(timezone=True))
__mapper_args__ = {
"polymorphic_identity": "book_doc",
}
class BlogPost(Base):
__tablename__ = 'blog_post'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
class BlogPost(SourceItem):
__tablename__ = "blog_post"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
url = Column(Text, unique=True)
title = Column(Text)
published = Column(DateTime(timezone=True))
__mapper_args__ = {
"polymorphic_identity": "blog_post",
}
class MiscDoc(Base):
__tablename__ = 'misc_doc'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
class MiscDoc(SourceItem):
__tablename__ = "misc_doc"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
path = Column(Text)
mime_type = Column(Text)
__mapper_args__ = {
"polymorphic_identity": "misc_doc",
}
class RssFeed(Base):
__tablename__ = 'rss_feeds'
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default='true')
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
# Add indexes
__table_args__ = (
Index('rss_feeds_active_idx', 'active', 'last_checked_at'),
Index('rss_feeds_tags_idx', 'tags', postgresql_using='gin'),
class GithubItem(SourceItem):
__tablename__ = "github_item"
id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
)
class EmailAccount(Base):
__tablename__ = 'email_accounts'
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default='993')
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default='true')
folders = Column(ARRAY(Text), nullable=False, server_default='{}')
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default='true')
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
# Add indexes
__table_args__ = (
Index('email_accounts_address_idx', 'email_address', unique=True),
Index('email_accounts_active_idx', 'active', 'last_sync_at'),
Index('email_accounts_tags_idx', 'tags', postgresql_using='gin'),
)
class GithubItem(Base):
__tablename__ = 'github_item'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
kind = Column(Text, nullable=False)
repo_path = Column(Text, nullable=False)
number = Column(Integer)
@ -259,19 +353,76 @@ class GithubItem(Base):
commit_sha = Column(Text)
state = Column(Text)
title = Column(Text)
body_raw = Column(Text)
labels = Column(ARRAY(Text))
author = Column(Text)
created_at = Column(DateTime(timezone=True))
closed_at = Column(DateTime(timezone=True))
merged_at = Column(DateTime(timezone=True))
diff_summary = Column(Text)
payload = Column(JSONB)
__mapper_args__ = {
"polymorphic_identity": "github_item",
}
__table_args__ = (
CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
Index('gh_repo_kind_idx', 'repo_path', 'kind'),
Index('gh_issue_lookup_idx', 'repo_path', 'kind', 'number'),
Index('gh_labels_idx', 'labels', postgresql_using='gin'),
)
Index("gh_repo_kind_idx", "repo_path", "kind"),
Index("gh_issue_lookup_idx", "repo_path", "kind", "number"),
Index("gh_labels_idx", "labels", postgresql_using="gin"),
)
class RssFeed(Base):
__tablename__ = "rss_feeds"
id = Column(BigInteger, primary_key=True)
url = Column(Text, nullable=False, unique=True)
title = Column(Text)
description = Column(Text)
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_checked_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("rss_feeds_active_idx", "active", "last_checked_at"),
Index("rss_feeds_tags_idx", "tags", postgresql_using="gin"),
)
class EmailAccount(Base):
__tablename__ = "email_accounts"
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default="993")
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default="true")
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default="true")
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# Add indexes
__table_args__ = (
Index("email_accounts_address_idx", "email_address", unique=True),
Index("email_accounts_active_idx", "active", "last_sync_at"),
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
)
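For orientation, a hedged sketch of how the new joined-table inheritance is meant to be used; the session setup is not part of this commit and all field values here are invented.

from memory.common.db.models import Chunk, MailMessage

msg = MailMessage(
    sha256=b"\x00" * 32,                 # normally a real content hash
    size=1234,
    mime_type="message/rfc822",
    tags=["inbox"],
    content="Message-ID: <x@example.com>\n\nHello Bob",
    message_id="<x@example.com>",
    subject="Hello",
    sender="alice@example.com",
    recipients=["bob@example.com"],
)
msg.chunks.append(Chunk(content="Hello Bob", embedding_model="text-model"))

# session.add(msg); session.commit()
# One row lands in source_item and one in mail_message, sharing the same id;
# msg.vector_ids then returns the UUIDs of the attached chunks.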

View File

@ -1,9 +1,17 @@
import logging
import pathlib
from typing import Literal, TypedDict, Iterable, Any
import voyageai
import re
import uuid
from typing import Any, Iterable, Literal, TypedDict
import voyageai
from PIL import Image
from memory.common import extract, settings
from memory.common.chunker import chunk_text
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
# Chunking configuration
MAX_TOKENS = 32000 # VoyageAI max context window
@ -27,16 +35,24 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1024, "distance": "Cosine"},
"chat": {"dimension": 1024, "distance": "Cosine"},
"git": {"dimension": 1024, "distance": "Cosine"},
"photo": {"dimension": 512, "distance": "Cosine"},
"book": {"dimension": 1024, "distance": "Cosine"},
"blog": {"dimension": 1024, "distance": "Cosine"},
"text": {"dimension": 1024, "distance": "Cosine"},
# Multimodal
"photo": {"dimension": 1024, "distance": "Cosine"},
"doc": {"dimension": 1024, "distance": "Cosine"},
}
TYPES = {
"doc": ["text/*"],
"doc": ["application/pdf", "application/docx", "application/msword"],
"text": ["text/*"],
"blog": ["text/markdown", "text/html"],
"photo": ["image/*"],
"book": ["application/pdf", "application/epub+zip", "application/mobi", "application/x-mobipocket-ebook"],
"book": [
"application/epub+zip",
"application/mobi",
"application/x-mobipocket-ebook",
],
}
@ -52,143 +68,55 @@ def get_modality(mime_type: str) -> str:
return "unknown"
# Regex for sentence splitting
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
def approx_token_count(s: str) -> int:
return len(s) // CHARS_PER_TOKEN
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
words = text.split()
if not words:
return
current = ""
for word in words:
new_chunk = f"{current} {word}".strip()
if current and approx_token_count(new_chunk) > max_tokens:
yield current
current = word
else:
current = new_chunk
if current: # Only yield non-empty final chunk
yield current
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
"""
Yield text spans in priority order: paragraphs, sentences, words.
Each span is guaranteed to be under max_tokens.
Args:
text: The text to split
max_tokens: Maximum tokens per chunk
Yields:
Spans of text that fit within token limits
"""
# Return early for empty text
if not text.strip():
return
for paragraph in text.split("\n\n"):
if not paragraph.strip():
continue
if approx_token_count(paragraph) <= max_tokens:
yield paragraph
continue
for sentence in _SENT_SPLIT_RE.split(paragraph):
if not sentence.strip():
continue
if approx_token_count(sentence) <= max_tokens:
yield sentence
continue
for chunk in yield_word_chunks(sentence, max_tokens):
yield chunk
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
"""
Split text into chunks respecting semantic boundaries while staying within token limits.
Args:
text: The text to chunk
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
overlap: Number of tokens to overlap between chunks (default: 200)
Returns:
List of text chunks
"""
text = text.strip()
if not text:
return
if approx_token_count(text) <= max_tokens:
yield text
return
overlap_chars = overlap * CHARS_PER_TOKEN
current = ""
for span in yield_spans(text, max_tokens):
current = f"{current} {span}".strip()
if approx_token_count(current) < max_tokens:
continue
if overlap <= 0:
yield current
current = ""
continue
overlap_text = current[-overlap_chars:]
clean_break = max(
overlap_text.rfind(". "),
overlap_text.rfind("! "),
overlap_text.rfind("? ")
)
if clean_break < 0:
yield current
current = ""
continue
break_offset = -overlap_chars + clean_break + 1
chunk = current[break_offset:].strip()
yield current
current = chunk
if current:
yield current.strip()
def embed_chunks(chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
def embed_chunks(
chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
) -> list[Vector]:
vo = voyageai.Client()
return vo.embed(chunks, model=model).embeddings
if model == settings.MIXED_EMBEDDING_MODEL:
return vo.multimodal_embed(
chunks, model=model, input_type="document"
).embeddings
return vo.embed(chunks, model=model, input_type="document").embeddings
def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
return embed_chunks(chunks, model)
def embed_text(
texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
) -> list[Vector]:
chunks = [
c
for text in texts
for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
if c.strip()
]
if not chunks:
return []
try:
return embed_chunks(chunks, model)
except voyageai.error.InvalidRequestError as e:
logger.error(f"Error embedding text: {e}")
logger.debug(f"Text: {texts}")
raise
def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
def embed_file(
file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL
) -> list[Vector]:
return embed_text([file_path.read_text()], model)
def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]:
def embed_mixed(
items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
) -> list[Vector]:
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
if isinstance(item, str):
return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
return [
c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
]
return [item]
chunks = [c for item in items for c in to_chunks(item)]
return embed_chunks(chunks, model)
return embed_chunks([chunks], model)
def embed_page(page: dict[str, Any]) -> list[Vector]:
@ -198,17 +126,64 @@ def embed_page(page: dict[str, Any]) -> list[Vector]:
return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
if isinstance(item, str):
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.txt"
filename.write_text(item)
elif isinstance(item, bytes):
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.bin"
filename.write_bytes(item)
elif isinstance(item, Image.Image):
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.png"
item.save(filename)
else:
raise ValueError(f"Unsupported content type: {type(item)}")
return filename
def make_chunk(
page: extract.Page, vector: Vector, metadata: dict[str, Any] = {}
) -> Chunk:
"""Create a Chunk object from a page and a vector.
This is quite complex, because we need to handle the case where the page is a single string,
a single image, or a list of strings and images.
"""
chunk_id = str(uuid.uuid4())
contents = page["contents"]
content, filename = None, None
if all(isinstance(c, str) for c in contents):
content = "\n\n".join(contents)
model = settings.TEXT_EMBEDDING_MODEL
elif len(contents) == 1:
filename = write_to_file(chunk_id, contents[0])
model = settings.MIXED_EMBEDDING_MODEL
else:
for i, item in enumerate(contents):
write_to_file(f"{chunk_id}_{i}", item)
model = settings.MIXED_EMBEDDING_MODEL
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}_*"
return Chunk(
id=chunk_id,
file_path=filename,
content=content,
embedding_model=model,
vector=vector,
item_metadata=metadata,
)
def embed(
mime_type: str,
content: bytes | str | pathlib.Path,
metadata: dict[str, Any] = {},
) -> tuple[str, list[Embedding]]:
modality = get_modality(mime_type)
pages = extract.extract_content(mime_type, content)
vectors = [
(str(uuid.uuid4()), vector, page.get("metadata", {}) | metadata)
chunks = [
make_chunk(page, vector, metadata)
for page in pages
for vector in embed_page(page)
]
return modality, vectors
return modality, chunks
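A hedged usage sketch of the reworked embed(); it assumes the module is importable as memory.common.embedding (consistent with imports elsewhere in this commit) and that a VoyageAI API key is configured. The metadata is invented.

from memory.common import embedding

modality, chunks = embedding.embed(
    "text/plain",
    "A short document that will be chunked and embedded.",
    metadata={"tags": ["example"]},
)

for chunk in chunks:
    # Each Chunk carries its own UUID (also used as the vector id),
    # the model that embedded it, and the merged page/user metadata.
    print(modality, chunk.id, chunk.embedding_model, chunk.item_metadata)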

View File

@ -7,6 +7,7 @@ from PIL import Image
from typing import Any, TypedDict, Generator
MulitmodalChunk = Image.Image | str
class Page(TypedDict):
contents: list[MulitmodalChunk]
@ -33,14 +34,16 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
with as_file(content) as file_path:
with pymupdf.open(file_path) as pdf:
return [{
"contents": page_to_image(page),
"metadata": {
"page": page.number,
"width": page.rect.width,
"height": page.rect.height,
}
} for page in pdf.pages()]
return [
{
"contents": page_to_image(page),
"metadata": {
"page": page.number,
"width": page.rect.width,
"height": page.rect.height,
}
} for page in pdf.pages()
]
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
@ -50,7 +53,7 @@ def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
image = Image.open(io.BytesIO(content))
else:
raise ValueError(f"Unsupported content type: {type(content)}")
return [{"contents": image, "metadata": {}}]
return [{"contents": [image], "metadata": {}}]
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:

View File

@ -0,0 +1,177 @@
import email
import hashlib
import logging
from datetime import datetime
from email.utils import parsedate_to_datetime
from typing import TypedDict, Literal
import pathlib
logger = logging.getLogger(__name__)
class Attachment(TypedDict):
filename: str
content_type: str
size: int
content: bytes
path: pathlib.Path
class EmailMessage(TypedDict):
message_id: str
subject: str
sender: str
recipients: list[str]
sent_at: datetime | None
body: str
attachments: list[Attachment]
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
def extract_recipients(msg: email.message.Message) -> list[str]:
"""
Extract email recipients from message headers.
Args:
msg: Email message object
Returns:
List of recipient email addresses
"""
return [
recipient
for field in ["To", "Cc", "Bcc"]
if (field_value := msg.get(field, ""))
for r in field_value.split(",")
if (recipient := r.strip())
]
def extract_date(msg: email.message.Message) -> datetime | None:
"""
Parse date from email header.
Args:
msg: Email message object
Returns:
Parsed datetime or None if parsing failed
"""
if date_str := msg.get("Date"):
try:
return parsedate_to_datetime(date_str)
except Exception:
logger.warning(f"Could not parse date: {date_str}")
return None
def extract_body(msg: email.message.Message) -> str:
"""
Extract plain text body from email message.
Args:
msg: Email message object
Returns:
Plain text body content
"""
body = ""
if not msg.is_multipart():
try:
return msg.get_payload(decode=True).decode(errors='replace')
except Exception as e:
logger.error(f"Error decoding message body: {str(e)}")
return ""
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
if content_type == "text/plain" and "attachment" not in content_disposition:
try:
body += part.get_payload(decode=True).decode(errors='replace') + "\n"
except Exception as e:
logger.error(f"Error decoding message part: {str(e)}")
return body
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
"""
Extract attachment metadata and content from email.
Args:
msg: Email message object
Returns:
List of attachment dictionaries with metadata and content
"""
if not msg.is_multipart():
return []
attachments = []
for part in msg.walk():
content_disposition = part.get("Content-Disposition", "")
if "attachment" not in content_disposition:
continue
if filename := part.get_filename():
try:
content = part.get_payload(decode=True)
attachments.append({
"filename": filename,
"content_type": part.get_content_type(),
"size": len(content),
"content": content
})
except Exception as e:
logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
return attachments
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
"""
Compute a SHA-256 hash of message content.
Args:
msg_id: Message ID
subject: Email subject
sender: Sender email
body: Message body
Returns:
SHA-256 hash as bytes
"""
hash_content = (msg_id + subject + sender + body).encode()
return hashlib.sha256(hash_content).digest()
def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
"""
Parse raw email into structured data.
Args:
raw_email: Raw email content as string
message_id: Fallback used to build a message ID when the header is missing
Returns:
Dict with parsed email data
"""
msg = email.message_from_string(raw_email)
message_id = msg.get("Message-ID") or f"generated-{message_id}"
subject = msg.get("Subject", "")
from_ = msg.get("From", "")
body = extract_body(msg)
return EmailMessage(
message_id=message_id,
subject=subject,
sender=from_,
recipients=extract_recipients(msg),
sent_at=extract_date(msg),
body=body,
attachments=extract_attachments(msg),
hash=compute_message_hash(message_id, subject, from_, body)
)
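A small, illustrative call to the parser above, using a made-up RFC 822 message and assuming the memory.common.parsers.email import path used elsewhere in this commit.

from memory.common.parsers.email import parse_email_message

raw = (
    "Message-ID: <1@example.com>\r\n"
    "Subject: Hi\r\n"
    "From: alice@example.com\r\n"
    "To: bob@example.com\r\n"
    "Date: Sat, 03 May 2025 14:00:00 +0200\r\n"
    "\r\n"
    "Hello Bob\r\n"
)

parsed = parse_email_message(raw, message_id="fallback-1")
print(parsed["subject"], parsed["sender"], parsed["recipients"])
print(parsed["hash"].hex())              # SHA-256 of the message content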

View File

@ -84,7 +84,9 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
if collections is None:
collections = DEFAULT_COLLECTIONS
logger.info(f"Initializing collections:")
for name, params in collections.items():
logger.info(f" - {name}")
ensure_collection_exists(
client,
collection_name=name,

View File

@ -39,8 +39,8 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
# Worker settings
BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600))
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600))
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
EMAIL_SYNC_INTERVAL = 60
# Embedding settings

View File

@ -2,6 +2,7 @@ import os
from celery import Celery
from memory.common import settings
def rabbit_url() -> str:
user = os.getenv("RABBITMQ_USER", "guest")
password = os.getenv("RABBITMQ_PASSWORD", "guest")
@ -32,3 +33,10 @@ app.conf.update(
"memory.workers.tasks.docs.*": {"queue": "docs"},
},
)
@app.on_after_configure.connect
def ensure_qdrant_initialised(sender, **_):
from memory.common import qdrant
qdrant.setup_qdrant()

View File

@ -1,200 +1,80 @@
import email
import hashlib
import imaplib
import logging
import re
import uuid
import base64
from contextlib import contextmanager
from datetime import datetime
from email.utils import parsedate_to_datetime
from typing import Generator, Callable, TypedDict, Literal
from typing import Generator, Callable
import pathlib
from sqlalchemy.orm import Session
from collections import defaultdict
from memory.common import settings, embedding
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import qdrant
from memory.common import settings, embedding, qdrant
from memory.common.db.models import (
EmailAccount,
MailMessage,
SourceItem,
EmailAttachment,
)
from memory.common.parsers.email import (
Attachment,
parse_email_message,
RawEmailResponse,
)
logger = logging.getLogger(__name__)
class Attachment(TypedDict):
filename: str
content_type: str
size: int
content: bytes
path: pathlib.Path
class EmailMessage(TypedDict):
message_id: str
subject: str
sender: str
recipients: list[str]
sent_at: datetime | None
body: str
attachments: list[Attachment]
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
def extract_recipients(msg: email.message.Message) -> list[str]:
"""
Extract email recipients from message headers.
Args:
msg: Email message object
Returns:
List of recipient email addresses
"""
return [
recipient
for field in ["To", "Cc", "Bcc"]
if (field_value := msg.get(field, ""))
for r in field_value.split(",")
if (recipient := r.strip())
]
def extract_date(msg: email.message.Message) -> datetime | None:
"""
Parse date from email header.
Args:
msg: Email message object
Returns:
Parsed datetime or None if parsing failed
"""
if date_str := msg.get("Date"):
try:
return parsedate_to_datetime(date_str)
except Exception:
logger.warning(f"Could not parse date: {date_str}")
return None
def extract_body(msg: email.message.Message) -> str:
"""
Extract plain text body from email message.
Args:
msg: Email message object
Returns:
Plain text body content
"""
body = ""
if not msg.is_multipart():
try:
return msg.get_payload(decode=True).decode(errors='replace')
except Exception as e:
logger.error(f"Error decoding message body: {str(e)}")
return ""
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
if content_type == "text/plain" and "attachment" not in content_disposition:
try:
body += part.get_payload(decode=True).decode(errors='replace') + "\n"
except Exception as e:
logger.error(f"Error decoding message part: {str(e)}")
return body
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
"""
Extract attachment metadata and content from email.
Args:
msg: Email message object
Returns:
List of attachment dictionaries with metadata and content
"""
if not msg.is_multipart():
return []
attachments = []
for part in msg.walk():
content_disposition = part.get("Content-Disposition", "")
if "attachment" not in content_disposition:
continue
if filename := part.get_filename():
try:
content = part.get_payload(decode=True)
attachments.append({
"filename": filename,
"content_type": part.get_content_type(),
"size": len(content),
"content": content
})
except Exception as e:
logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
return attachments
def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAttachment | None:
def process_attachment(
attachment: Attachment, message: MailMessage
) -> EmailAttachment | None:
"""Process an attachment, storing large files on disk and returning metadata.
Args:
attachment: Attachment dictionary with metadata and content
message: Parent MailMessage, used for tags and attachment file paths
Returns:
EmailAttachment record, or None if the content could not be stored
"""
content, file_path = None, None
if not (real_content := attachment.get("content")):
"No content, so just save the metadata"
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE:
content = base64.b64encode(real_content)
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE and attachment[
"content_type"
].startswith("text/"):
content = real_content.decode("utf-8", errors="replace")
else:
safe_filename = re.sub(r'[/\\]', '_', attachment["filename"])
user_dir = message.attachments_path
user_dir.mkdir(parents=True, exist_ok=True)
file_path = user_dir / safe_filename
file_path = message.safe_filename(attachment["filename"])
try:
file_path.write_bytes(real_content)
except Exception as e:
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
logger.error(f"Failed to save attachment {file_path} to disk: {str(e)}")
return None
source_item = SourceItem(
modality=embedding.get_modality(attachment["content_type"]),
sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(),
tags=message.source.tags,
byte_length=attachment["size"],
mime_type=attachment["content_type"],
)
return EmailAttachment(
source=source_item,
filename=attachment["filename"],
modality=embedding.get_modality(attachment["content_type"]),
sha256=hashlib.sha256(
real_content if real_content else str(attachment).encode()
).digest(),
tags=message.tags,
size=attachment["size"],
mime_type=attachment["content_type"],
mail_message=message,
content_type=attachment.get("content_type"),
size=attachment.get("size"),
content=content,
file_path=file_path and str(file_path),
filename=file_path and str(file_path),
)
def process_attachments(attachments: list[Attachment], message: MailMessage) -> list[EmailAttachment]:
def process_attachments(
attachments: list[Attachment], message: MailMessage
) -> list[EmailAttachment]:
"""
Process email attachments, storing large files on disk and returning metadata.
Args:
attachments: List of attachment dictionaries with metadata and content
message: Parent MailMessage used for tags and file path generation
Returns:
List of EmailAttachment records with appropriate metadata
"""
@ -203,150 +83,111 @@ def process_attachments(attachments: list[Attachment], message: MailMessage) ->
return [
attachment
for a in attachments if (attachment := process_attachment(a, message))
for a in attachments
if (attachment := process_attachment(a, message))
]
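# The comprehension above is equivalent to this explicit loop (illustrative
# rewrite only, not part of the repo): attachments whose content could not be
# processed return None from process_attachment and are silently dropped.
def process_attachments_explicit(
    attachments: list[Attachment], message: MailMessage
) -> list[EmailAttachment]:
    results: list[EmailAttachment] = []
    for a in attachments:
        processed = process_attachment(a, message)
        if processed:
            results.append(processed)
    return results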
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
"""
Compute a SHA-256 hash of message content.
Args:
msg_id: Message ID
subject: Email subject
sender: Sender email
body: Message body
Returns:
SHA-256 hash as bytes
"""
hash_content = (msg_id + subject + sender + body).encode()
return hashlib.sha256(hash_content).digest()
def parse_email_message(raw_email: str) -> EmailMessage:
"""
Parse raw email into structured data.
Args:
raw_email: Raw email content as string
Returns:
Dict with parsed email data
"""
msg = email.message_from_string(raw_email)
return EmailMessage(
message_id=msg.get("Message-ID", ""),
subject=msg.get("Subject", ""),
sender=msg.get("From", ""),
recipients=extract_recipients(msg),
sent_at=extract_date(msg),
body=extract_body(msg),
attachments=extract_attachments(msg)
)
def create_source_item(
db_session: Session,
message_hash: bytes,
account_tags: list[str],
raw_size: int,
modality: str = "mail",
mime_type: str = "message/rfc822",
) -> SourceItem:
"""
Create a new source item record.
Args:
db_session: Database session
message_hash: SHA-256 hash of message
account_tags: Tags from the email account
raw_size: Size of raw email in bytes
Returns:
Newly created SourceItem
"""
source_item = SourceItem(
modality=modality,
sha256=message_hash,
tags=account_tags,
byte_length=raw_size,
mime_type=mime_type,
embed_status="RAW"
)
db_session.add(source_item)
db_session.flush()
return source_item
def create_mail_message(
db_session: Session,
source_item: SourceItem,
parsed_email: EmailMessage,
tags: list[str],
folder: str,
raw_email: str,
message_id: str,
) -> MailMessage:
"""
Create a new mail message record and associated attachments.
Args:
db_session: Database session
tags: Tags to apply to the message
folder: IMAP folder name
raw_email: Raw email content as string
message_id: UID of the message on the server
Returns:
Newly created MailMessage
"""
parsed_email = parse_email_message(raw_email, message_id)
mail_message = MailMessage(
source=source_item,
modality="mail",
sha256=parsed_email["hash"],
tags=tags,
size=len(raw_email),
mime_type="message/rfc822",
embed_status="RAW",
message_id=parsed_email["message_id"],
subject=parsed_email["subject"],
sender=parsed_email["sender"],
recipients=parsed_email["recipients"],
sent_at=parsed_email["sent_at"],
body_raw=parsed_email["body"],
content=raw_email,
folder=folder,
)
db_session.add(mail_message)
db_session.flush()
if parsed_email["attachments"]:
processed_attachments = process_attachments(parsed_email["attachments"], mail_message)
db_session.add_all(processed_attachments)
mail_message.attachments = process_attachments(
parsed_email["attachments"], mail_message
)
db_session.add(mail_message)
return mail_message
def check_message_exists(db_session: Session, message_id: str, message_hash: bytes) -> bool:
def does_message_exist(
db_session: Session, message_id: str, message_hash: bytes
) -> bool:
"""
Check if a message already exists in the database.
Args:
db_session: Database session
message_id: Email message ID
message_hash: SHA-256 hash of message
Returns:
True if message exists, False otherwise
"""
# Check by message_id first (faster)
if message_id:
mail_message = db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first()
mail_message = (
db_session.query(MailMessage)
.filter(MailMessage.message_id == message_id)
.first()
)
if mail_message is not None:
return True
# Then check by message_hash
source_item = db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
source_item = (
db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
)
return source_item is not None
def check_message_exists(
db: Session, account_id: int, message_id: str, raw_email: str
) -> bool:
account = db.query(EmailAccount).get(account_id)
if not account:
logger.error(f"Account {account_id} not found")
return False
parsed_email = parse_email_message(raw_email, message_id)
# Use server-provided message ID if missing
if not parsed_email["message_id"]:
parsed_email["message_id"] = f"generated-{message_id}"
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
"""
Extract the UID and raw email data from the message data.
"""
uid_pattern = re.compile(r'UID (\d+)')
uid_match = uid_pattern.search(msg_data[0][0].decode('utf-8', errors='replace'))
uid_pattern = re.compile(r"UID (\d+)")
uid_match = uid_pattern.search(msg_data[0][0].decode("utf-8", errors="replace"))
uid = uid_match.group(1) if uid_match else None
raw_email = msg_data[0][1]
return uid, raw_email
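# A minimal illustration (hypothetical data, not from this repo) of the imaplib
# fetch response shape this function expects for conn.fetch(uid, "(UID RFC822)"):
# a list whose first element pairs the header section with the raw RFC822 bytes.
fake_msg_data = [
    (b"1 (UID 4321 RFC822 {23}", b"Subject: hi\r\n\r\nhello\r\n"),
    b")",
]
# extract_email_uid(fake_msg_data) would return ("4321", b"Subject: hi\r\n\r\nhello\r\n")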
@ -354,11 +195,11 @@ def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
try:
status, msg_data = conn.fetch(uid, '(UID RFC822)')
if status != 'OK' or not msg_data or not msg_data[0]:
status, msg_data = conn.fetch(uid, "(UID RFC822)")
if status != "OK" or not msg_data or not msg_data[0]:
logger.error(f"Error fetching message {uid}")
return None
return extract_email_uid(msg_data)
except Exception as e:
logger.error(f"Error processing message {uid}: {str(e)}")
@ -368,38 +209,38 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
def fetch_email_since(
conn: imaplib.IMAP4_SSL,
folder: str,
since_date: datetime = datetime(1970, 1, 1)
since_date: datetime = datetime(1970, 1, 1),
) -> list[RawEmailResponse]:
"""
Fetch emails from a folder since a given date.
Fetch emails from a folder since a given date and time.
Args:
conn: IMAP connection
folder: Folder name to select
since_date: Fetch emails since this date
since_date: Fetch emails since this date and time
Returns:
List of tuples with (uid, raw_email)
"""
try:
status, counts = conn.select(folder)
if status != 'OK':
if status != "OK":
logger.error(f"Error selecting folder {folder}: {counts}")
return []
date_str = since_date.strftime("%d-%b-%Y")
status, data = conn.search(None, f'(SINCE "{date_str}")')
if status != 'OK':
if status != "OK":
logger.error(f"Error searching folder {folder}: {data}")
return []
except Exception as e:
logger.error(f"Error in fetch_email_since for folder {folder}: {str(e)}")
return []
if not data or not data[0]:
return []
return [email for uid in data[0].split() if (email := fetch_email(conn, uid))]
@ -412,13 +253,14 @@ def process_folder(
) -> dict:
"""
Process a single folder from an email account.
Args:
conn: Active IMAP connection
folder: Folder name to process
account: Email account configuration
since_date: Only fetch messages newer than this date
processor: Function to process each message
Returns:
Stats dictionary for the folder
"""
@ -427,21 +269,21 @@ def process_folder(
try:
emails = fetch_email_since(conn, folder, since_date)
for uid, raw_email in emails:
try:
task = processor(
account_id=account.id,
message_id=uid,
folder=folder,
raw_email=raw_email.decode('utf-8', errors='replace')
raw_email=raw_email.decode("utf-8", errors="replace"),
)
if task:
new_messages += 1
except Exception as e:
logger.error(f"Error queuing message {uid}: {str(e)}")
errors += 1
except Exception as e:
logger.error(f"Error processing folder {folder}: {str(e)}")
errors += 1
@ -449,16 +291,13 @@ def process_folder(
return {
"messages_found": len(emails),
"new_messages": new_messages,
"errors": errors
"errors": errors,
}
@contextmanager
def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
conn = imaplib.IMAP4_SSL(
host=account.imap_server,
port=account.imap_port
)
conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port)
try:
conn.login(account.username, account.password)
yield conn
@ -470,43 +309,56 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
def vectorize_email(email: MailMessage) -> list[float]:
def vectorize_email(email: MailMessage):
qdrant_client = qdrant.get_qdrant_client()
_, chunks = embedding.embed(
"text/plain", email.body_raw, metadata=email.as_payload(),
"text/plain",
email.body,
metadata=email.as_payload(),
)
vector_ids, vectors, metadata = zip(*chunks)
qdrant.upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=vector_ids,
vectors=vectors,
payloads=metadata,
)
vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
email.chunks = chunks
if chunks:
vector_ids = [c.id for c in chunks]
vectors = [c.vector for c in chunks]
metadata = [c.item_metadata for c in chunks]
qdrant.upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=vector_ids,
vectors=vectors,
payloads=metadata,
)
embeds = defaultdict(list)
for attachment in email.attachments:
if attachment.file_path:
content = pathlib.Path(attachment.file_path).read_bytes()
if attachment.filename:
content = pathlib.Path(attachment.filename).read_bytes()
else:
content = attachment.content
collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload())
ids, vectors, metadata = zip(*chunks)
attachment.source.vector_ids = ids
collection, chunks = embedding.embed(
attachment.mime_type, content, metadata=attachment.as_payload()
)
if not chunks:
continue
attachment.chunks = chunks
embeds[collection].extend(chunks)
for collection, chunks in embeds.items():
ids, vectors, payloads = zip(*chunks)
ids = [c.id for c in chunks]
vectors = [c.vector for c in chunks]
metadata = [c.item_metadata for c in chunks]
qdrant.upsert_vectors(
client=qdrant_client,
collection_name=collection,
ids=ids,
vectors=vectors,
payloads=payloads,
payloads=metadata,
)
vector_ids.extend([f"{collection}/{vector_id}" for vector_id in ids])
email.embed_status = "STORED"
for attachment in email.attachments:
attachment.embed_status = "STORED"
logger.info(f"Stored embedding for message {email.message_id}")
return vector_ids
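# For reference, a minimal sketch of the chunk objects vectorize_email assumes
# (a guess pieced together from this diff and the embedding tests below; the
# real objects come from embedding.embed and their definition may differ):
from dataclasses import dataclass
from typing import Any

@dataclass
class ChunkSketch:
    id: str                             # UUID string used as the Qdrant point id
    vector: list[float]                 # embedding for this chunk
    item_metadata: dict[str, Any]       # payload stored next to the vector
    content: str | None = None          # inline text, when the chunk is textual
    file_path: str | None = None        # on-disk location for non-text content
    embedding_model: str | None = None  # model that produced the vector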

View File

@ -1,14 +1,11 @@
from celery.schedules import schedule
from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
from memory.common import settings
from memory.workers.celery_app import app
from memory.common import settings
@app.on_after_configure.connect
def register_mail_schedules(sender, **_):
sender.add_periodic_task(
schedule=schedule(settings.EMAIL_SYNC_INTERVAL),
sig=app.signature(SYNC_ALL_ACCOUNTS),
name="sync-mail-all",
options={"queue": "email"},
)
app.conf.beat_schedule = {
'sync-mail-all': {
'task': 'memory.workers.tasks.email.sync_all_accounts',
'schedule': settings.EMAIL_SYNC_INTERVAL,
},
}

View File

@ -1,8 +1,8 @@
"""
Import sub-modules so Celery can register their @app.task decorators.
"""
from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa
from memory.workers.tasks import docs, email # noqa
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]
__all__ = ["docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]

View File

@ -6,11 +6,8 @@ from memory.common.db.models import EmailAccount
from memory.workers.celery_app import app
from memory.workers.email import (
check_message_exists,
compute_message_hash,
create_mail_message,
create_source_item,
imap_connection,
parse_email_message,
process_folder,
vectorize_email,
)
@ -25,17 +22,20 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
@app.task(name=PROCESS_EMAIL)
def process_message(
account_id: int, message_id: str, folder: str, raw_email: str,
account_id: int,
message_id: str,
folder: str,
raw_email: str,
) -> int | None:
"""
Process a single email message and store it in the database.
Args:
account_id: ID of the EmailAccount
message_id: UID of the message on the server
folder: Folder name where the message is stored
raw_email: Raw email content as string
Returns:
MailMessage id if successful, None otherwise
"""
@ -45,88 +45,86 @@ def process_message(
return None
with make_session() as db:
if check_message_exists(db, account_id, message_id, raw_email):
logger.debug(f"Message {message_id} already exists in database")
return None
account = db.query(EmailAccount).get(account_id)
if not account:
logger.error(f"Account {account_id} not found")
return None
parsed_email = parse_email_message(raw_email)
# Use server-provided message ID if missing
if not parsed_email["message_id"]:
parsed_email["message_id"] = f"generated-{message_id}"
message_hash = compute_message_hash(
parsed_email["message_id"],
parsed_email["subject"],
parsed_email["sender"],
parsed_email["body"]
mail_message = create_mail_message(
db, account.tags, folder, raw_email, message_id
)
if check_message_exists(db, parsed_email["message_id"], message_hash):
logger.debug(f"Message {parsed_email['message_id']} already exists in database")
return None
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
mail_message = create_mail_message(db, source_item, parsed_email, folder)
source_item.vector_ids = vectorize_email(mail_message)
vectorize_email(mail_message)
db.commit()
logger.info(f"Stored embedding for message {parsed_email['message_id']}")
logger.info("Vector IDs:")
for vector_id in source_item.vector_ids:
logger.info(f" - {vector_id}")
return source_item.id
logger.info(f"Stored embedding for message {mail_message.message_id}")
logger.info("Chunks:")
for chunk in mail_message.chunks:
logger.info(f" - {chunk.id}")
return mail_message.id
@app.task(name=SYNC_ACCOUNT)
def sync_account(account_id: int) -> dict:
"""
Synchronize emails from a specific account.
Args:
account_id: ID of the EmailAccount to sync
Returns:
dict with stats about the sync operation
"""
logger.info(f"Syncing account {account_id}")
with make_session() as db:
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
if not account or not account.active:
logger.warning(f"Account {account_id} not found or inactive")
return {"error": "Account not found or inactive"}
folders_to_process = account.folders or ["INBOX"]
since_date = account.last_sync_at or datetime(1970, 1, 1)
messages_found = 0
new_messages = 0
errors = 0
def process_message_wrapper(
account_id: int, message_id: str, folder: str, raw_email: str
) -> int | None:
if check_message_exists(db, account_id, message_id, raw_email):
return None
return process_message.delay(account_id, message_id, folder, raw_email)
try:
with imap_connection(account) as conn:
for folder in folders_to_process:
folder_stats = process_folder(conn, folder, account, since_date, process_message.delay)
folder_stats = process_folder(
conn, folder, account, since_date, process_message_wrapper
)
messages_found += folder_stats["messages_found"]
new_messages += folder_stats["new_messages"]
errors += folder_stats["errors"]
account.last_sync_at = datetime.now()
db.commit()
except Exception as e:
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
return {"error": str(e)}
return {
"account": account.email_address,
"folders_processed": len(folders_to_process),
"messages_found": messages_found,
"new_messages": new_messages,
"errors": errors
"errors": errors,
}
@ -134,18 +132,18 @@ def sync_account(account_id: int) -> dict:
def sync_all_accounts() -> list[dict]:
"""
Synchronize all active email accounts.
Returns:
List of task IDs that were scheduled
"""
with make_session() as db:
active_accounts = db.query(EmailAccount).filter(EmailAccount.active).all()
return [
{
"account_id": account.id,
"email": account.email_address,
"task_id": sync_account.delay(account.id).id
"task_id": sync_account.delay(account.id).id,
}
for account in active_accounts
]
]

View File

@ -1,5 +0,0 @@
from memory.workers.celery_app import app
@app.task(name="kb.text.ping")
def ping():
return "pong"

View File

@ -1,5 +0,0 @@
from memory.workers.celery_app import app
@app.task(name="kb.text.ping")
def ping():
return "pong"

View File

@ -1,5 +0,0 @@
from memory.workers.celery_app import app
@app.task(name="kb.text.ping")
def ping():
return "pong"

View File

@ -1,6 +0,0 @@
from memory.workers.celery_app import app
@app.task(name="kb.text.ping")
def ping():
return "pong"

View File

@ -1,5 +0,0 @@
from memory.workers.celery_app import app
@app.task(name="memory.text.ping")
def ping():
return "pong"

View File

@ -0,0 +1,264 @@
import email
import email.mime.multipart
import email.mime.text
import email.mime.base
from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY, patch
import pytest
from memory.common.parsers.email import (
compute_message_hash,
extract_attachments,
extract_body,
extract_date,
extract_recipients,
parse_email_message,
)
# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0
def _generate_test_message_id():
"""Generate a simple message ID for testing without expensive calls"""
global _msg_id_counter
_msg_id_counter += 1
return f"<test-message-{_msg_id_counter}@example.com>"
def create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
cc_addrs=None,
bcc_addrs=None,
date=None,
body="Test body content",
attachments=None,
multipart=True,
message_id=None,
):
"""Helper function to create email.message.Message objects for testing"""
if multipart:
msg = email.mime.multipart.MIMEMultipart()
msg.attach(email.mime.text.MIMEText(body))
if attachments:
for attachment in attachments:
attachment_part = email.mime.base.MIMEBase(
"application", "octet-stream"
)
attachment_part.set_payload(attachment["content"])
attachment_part.add_header(
"Content-Disposition",
f"attachment; filename={attachment['filename']}",
)
msg.attach(attachment_part)
else:
msg = email.mime.text.MIMEText(body)
msg["Subject"] = subject
msg["From"] = from_addr
msg["To"] = to_addrs
if cc_addrs:
msg["Cc"] = cc_addrs
if bcc_addrs:
msg["Bcc"] = bcc_addrs
if date:
msg["Date"] = formatdate(float(date.timestamp()))
if message_id:
msg["Message-ID"] = message_id
else:
msg["Message-ID"] = _generate_test_message_id()
return msg
@pytest.mark.parametrize(
"to_addr, cc_addr, bcc_addr, expected",
[
# Single recipient in To field
("recipient@example.com", None, None, ["recipient@example.com"]),
# Multiple recipients in To field
(
"recipient1@example.com, recipient2@example.com",
None,
None,
["recipient1@example.com", "recipient2@example.com"],
),
# To, Cc fields
(
"recipient@example.com",
"cc@example.com",
None,
["recipient@example.com", "cc@example.com"],
),
# To, Cc, Bcc fields
(
"recipient@example.com",
"cc@example.com",
"bcc@example.com",
["recipient@example.com", "cc@example.com", "bcc@example.com"],
),
# Empty fields
("", "", "", []),
],
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
assert sorted(extract_recipients(msg)) == sorted(expected)
def test_extract_date_missing():
msg = create_email_message(date=None)
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Invalid Date Format",
"2023-01-01", # ISO format but not RFC compliant
"Monday, Jan 1, 2023", # Descriptive but not RFC compliant
"01/01/2023", # Common format but not RFC compliant
"", # Empty string
],
)
def test_extract_date_invalid_formats(date_str):
msg = create_email_message()
msg["Date"] = date_str
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format
"01 Jan 2023 12:00:00 +0000", # RFC 822 format
"Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name
],
)
def test_extract_date(date_str):
msg = create_email_message()
msg["Date"] = date_str
result = extract_date(msg)
assert result is not None
assert result.year == 2023
assert result.month == 1
assert result.day == 1
@pytest.mark.parametrize("multipart", [True, False])
def test_extract_body_text_plain(multipart):
body_content = "This is a test email body"
msg = create_email_message(body=body_content, multipart=multipart)
extracted = extract_body(msg)
# Strip newlines for comparison since multipart emails often add them
assert extracted.strip() == body_content.strip()
def test_extract_body_with_attachments():
body_content = "This is a test email body"
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(body=body_content, attachments=attachments)
assert body_content in extract_body(msg)
def test_extract_attachments_none():
msg = create_email_message(multipart=True)
assert extract_attachments(msg) == []
def test_extract_attachments_with_files():
attachments = [
{"filename": "test1.txt", "content": b"content1"},
{"filename": "test2.pdf", "content": b"content2"},
]
msg = create_email_message(attachments=attachments)
result = extract_attachments(msg)
assert len(result) == 2
assert result[0]["filename"] == "test1.txt"
assert result[1]["filename"] == "test2.pdf"
def test_extract_attachments_non_multipart():
msg = create_email_message(multipart=False)
assert extract_attachments(msg) == []
@pytest.mark.parametrize(
"msg_id, subject, sender, body, expected",
[
(
"<test@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xf2\xbd", # First two bytes of the actual hash
),
(
"<different@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xa4\x15", # Will be different from the first hash
),
],
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
result = compute_message_hash(msg_id, subject, sender, body)
# Verify it's bytes and correct length for SHA-256 (32 bytes)
assert isinstance(result, bytes)
assert len(result) == 32
# Verify first two bytes match expected
assert result[:2] == expected
def test_hash_consistency():
args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
assert compute_message_hash(*args) == compute_message_hash(*args)
def test_parse_simple_email():
test_date = datetime(2023, 1, 1, 12, 0, 0)
msg_id = "<test123@example.com>"
msg = create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
date=test_date,
body="Test body content",
message_id=msg_id,
)
result = parse_email_message(msg.as_string(), msg_id)
assert result == {
"message_id": msg_id,
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient@example.com"],
"body": "Test body content\n",
"attachments": [],
"sent_at": ANY,
"hash": b'\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87'
+ b'\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1',
}
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
def test_parse_email_with_attachments():
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(attachments=attachments)
result = parse_email_message(msg.as_string(), "123")
assert len(result["attachments"]) == 1
assert result["attachments"][0]["filename"] == "test.txt"

View File

@ -0,0 +1,285 @@
import pytest
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
@pytest.mark.parametrize(
"text, expected",
[
("", []),
("hello", ["hello"]),
("This is a simple sentence", ["This is a simple sentence"]),
("word1 word2", ["word1 word2"]),
(" ", []), # Just spaces
("\n\t ", []), # Whitespace characters
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
]
)
def test_yield_word_chunk_basic_behavior(text, expected):
"""Test basic behavior of yield_word_chunks with various inputs"""
assert list(yield_word_chunks(text)) == expected
@pytest.mark.parametrize(
"text, expected",
[
(
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
),
(
"supercalifragilisticexpialidocious",
["supercalifragilisticexpialidocious"],
)
]
)
def test_yield_word_chunk_long_text(text, expected):
"""Test chunking with long text that exceeds token limits"""
assert list(yield_word_chunks(text, max_tokens=10)) == expected
def test_yield_word_chunk_single_long_word():
"""Test behavior with a single word longer than the token limit"""
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
chunks = list(yield_word_chunks(long_word, max_tokens))
# With our changes, this should be a single chunk
assert len(chunks) == 1
assert chunks[0] == long_word
def test_yield_word_chunk_small_token_limit():
"""Test with a very small max_tokens value to force chunking"""
text = "one two three four five"
max_tokens = 1 # Very small to force chunking after each word
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
@pytest.mark.parametrize(
"text, max_tokens, expected_chunks",
[
# Empty text
("", 10, []),
# Text below token limit
("hello world", 10, ["hello world"]),
# Text right at token limit
(
"word1 word2", # 11 chars with space
3, # 12 chars limit
["word1 word2"]
),
# Text just over the token limit still comes back as a single chunk
(
"word1 word2 word3", # 17 chars with spaces
4, # 16 chars limit
["word1 word2 word3"]
),
# Each word exactly at token limit
(
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
1, # 1 token limit (4 chars)
["aaaa", "bbbb", "cccc"]
),
]
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
"""Test different combinations of text and token limits"""
assert list(yield_word_chunks(text, max_tokens)) == expected_chunks
def test_yield_word_chunk_real_world_example():
"""Test with a realistic text example"""
text = (
"The yield_word_chunks function splits text into chunks based on word boundaries. "
"It tries to maximize chunk size while staying under the specified token limit. "
"This behavior is essential for processing large documents efficiently."
)
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
assert list(yield_word_chunks(text, max_tokens)) == [
'The yield_word_chunks function splits text',
'into chunks based on word boundaries. It',
'tries to maximize chunk size while staying',
'under the specified token limit. This',
'behavior is essential for processing large',
'documents efficiently.',
]
# Tests for yield_spans function
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text should yield nothing
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
(" ", []), # Just whitespace
]
)
def test_yield_spans_basic_behavior(text, expected):
"""Test basic behavior of yield_spans with various inputs"""
assert list(yield_spans(text)) == expected
def test_yield_spans_paragraphs():
"""Test splitting by paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
assert list(yield_spans(text)) == expected
def test_yield_spans_sentences():
"""Test splitting by sentences when paragraphs exceed token limit"""
# Create a paragraph that exceeds token limit but sentences are within limit
max_tokens = 5 # 20 chars with CHARS_PER_TOKEN = 4
sentence1 = "Short sentence one." # ~20 chars
sentence2 = "Another short sentence." # ~24 chars
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
# Function should now preserve punctuation
expected = ["Short sentence one.", "Another short sentence."]
assert list(yield_spans(text, max_tokens)) == expected
def test_yield_spans_words():
"""Test splitting by words when sentences exceed token limit"""
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
long_sentence = "This sentence has several words and needs word-level chunking."
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
def test_yield_spans_complex_document():
"""Test with a document containing multiple paragraphs and sentences"""
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
text = (
"Paragraph one with a short sentence. And another sentence that should be split.\n\n"
"Paragraph two is also here. It has multiple sentences. Some are short. "
"This one is longer and might need word splitting depending on the limit.\n\n"
"Final short paragraph."
)
assert list(yield_spans(text, max_tokens)) == [
"Paragraph one with a short sentence.",
"And another sentence that should be split.",
"Paragraph two is also here.",
"It has multiple sentences.",
"Some are short.",
"This one is longer and might need word",
"splitting depending on the limit.",
"Final short paragraph."
]
def test_yield_spans_very_long_word():
"""Test with a word that exceeds the token limit"""
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
assert list(yield_spans(long_word, max_tokens)) == [long_word]
def test_yield_spans_with_punctuation():
"""Test sentence splitting with various punctuation"""
text = "First sentence! Second sentence? Third sentence."
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
def test_yield_spans_edge_cases():
"""Test edge cases like empty paragraphs, single character paragraphs"""
text = "\n\nA\n\n\n\nB\n\n"
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text
("Short text", ["Short text"]), # Text below token limit
(" ", []), # Just whitespace
]
)
def test_chunk_text_basic_behavior(text, expected):
"""Test basic behavior of chunk_text with various inputs"""
assert list(chunk_text(text)) == expected
def test_chunk_text_single_paragraph():
"""Test chunking a single paragraph that fits within token limit"""
text = "This is a simple paragraph that should fit in one chunk."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_multi_paragraph():
"""Test chunking multiple paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_long_text():
"""Test chunking with long text that exceeds token limit"""
# Create a long text that will need multiple chunks
sentences = [f"This is sentence {i:02}." for i in range(50)]
text = " ".join(sentences)
max_tokens = 10 # 10 tokens = ~40 chars
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
] + [
'This is sentence 49.'
]
def test_chunk_text_with_overlap():
"""Test chunking with overlap between chunks"""
# Create text with distinct parts to test overlap
text = "Part A. Part B. Part C. Part D. Part E."
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
def test_chunk_text_zero_overlap():
"""Test chunking with zero overlap"""
text = "Part A. Part B. Part C. Part D. Part E."
# 2 tokens = ~8 chars
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
def test_chunk_text_clean_break():
"""Test that chunking attempts to break at sentence boundaries"""
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
max_tokens = 5 # Enough for about 2 sentences
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
def test_chunk_text_very_long_sentences():
"""Test with very long sentences that exceed the token limit"""
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
max_tokens = 5 # Small limit to force splitting
assert list(chunk_text(text, max_tokens=max_tokens)) == [
'This is a very long sentence with many many',
'words that will definitely exceed the',
'token limit we set for',
'this particular test',
'case and should be split into multiple',
'chunks by the function.',
]
@pytest.mark.parametrize(
"string, expected_count",
[
("", 0),
("a" * CHARS_PER_TOKEN, 1),
("a" * (CHARS_PER_TOKEN * 2), 2),
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
]
)
def test_approx_token_count(string, expected_count):
assert approx_token_count(string) == expected_count
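# A minimal sketch of what the cases above imply about the implementation
# (an assumption for illustration; the real code lives in memory.common.chunker):
CHARS_PER_TOKEN_ASSUMED = 4  # matches the "CHARS_PER_TOKEN = 4" comments in these tests

def approx_token_count_sketch(s: str) -> int:
    # Integer division truncates, which is what the "Truncation" cases expect.
    return len(s) // CHARS_PER_TOKEN_ASSUMED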

View File

@ -1,321 +1,51 @@
import uuid
import pytest
from unittest.mock import Mock, patch
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed
from PIL import Image
import pathlib
from memory.common import settings
from memory.common.embedding import (
get_modality,
embed_text,
embed_file,
embed_mixed,
embed_page,
embed,
write_to_file,
make_chunk,
)
@pytest.fixture
def mock_embed(mock_voyage_client):
vectors = ([i] for i in range(1000))
def embed(texts, model):
def embed(texts, model, input_type):
return Mock(embeddings=[next(vectors) for _ in texts])
mock_voyage_client.embed.side_effect = embed
mock_voyage_client.multimodal_embed.side_effect = embed
return mock_voyage_client
@pytest.mark.parametrize(
"text, expected",
[
("", []),
("hello", ["hello"]),
("This is a simple sentence", ["This is a simple sentence"]),
("word1 word2", ["word1 word2"]),
(" ", []), # Just spaces
("\n\t ", []), # Whitespace characters
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
]
)
def test_yield_word_chunk_basic_behavior(text, expected):
"""Test basic behavior of yield_word_chunks with various inputs"""
assert list(yield_word_chunks(text)) == expected
@pytest.mark.parametrize(
"text, expected",
[
(
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
),
(
"supercalifragilisticexpialidocious",
["supercalifragilisticexpialidocious"],
)
]
)
def test_yield_word_chunk_long_text(text, expected):
"""Test chunking with long text that exceeds token limits"""
assert list(yield_word_chunks(text, max_tokens=10)) == expected
def test_yield_word_chunk_single_long_word():
"""Test behavior with a single word longer than the token limit"""
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
chunks = list(yield_word_chunks(long_word, max_tokens))
# With our changes, this should be a single chunk
assert len(chunks) == 1
assert chunks[0] == long_word
def test_yield_word_chunk_small_token_limit():
"""Test with a very small max_tokens value to force chunking"""
text = "one two three four five"
max_tokens = 1 # Very small to force chunking after each word
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
@pytest.mark.parametrize(
"text, max_tokens, expected_chunks",
[
# Empty text
("", 10, []),
# Text below token limit
("hello world", 10, ["hello world"]),
# Text right at token limit
(
"word1 word2", # 11 chars with space
3, # 12 chars limit
["word1 word2"]
),
# Text just over token limit should split
(
"word1 word2 word3", # 17 chars with spaces
4, # 16 chars limit
["word1 word2 word3"]
),
# Each word exactly at token limit
(
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
1, # 1 token limit (4 chars)
["aaaa", "bbbb", "cccc"]
),
]
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
"""Test different combinations of text and token limits"""
assert list(yield_word_chunks(text, max_tokens)) == expected_chunks
def test_yield_word_chunk_real_world_example():
"""Test with a realistic text example"""
text = (
"The yield_word_chunks function splits text into chunks based on word boundaries. "
"It tries to maximize chunk size while staying under the specified token limit. "
"This behavior is essential for processing large documents efficiently."
)
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
assert list(yield_word_chunks(text, max_tokens)) == [
'The yield_word_chunks function splits text',
'into chunks based on word boundaries. It',
'tries to maximize chunk size while staying',
'under the specified token limit. This',
'behavior is essential for processing large',
'documents efficiently.',
]
# Tests for yield_spans function
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text should yield nothing
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
(" ", []), # Just whitespace
]
)
def test_yield_spans_basic_behavior(text, expected):
"""Test basic behavior of yield_spans with various inputs"""
assert list(yield_spans(text)) == expected
def test_yield_spans_paragraphs():
"""Test splitting by paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
assert list(yield_spans(text)) == expected
def test_yield_spans_sentences():
"""Test splitting by sentences when paragraphs exceed token limit"""
# Create a paragraph that exceeds token limit but sentences are within limit
max_tokens = 5 # 20 chars with CHARS_PER_TOKEN = 4
sentence1 = "Short sentence one." # ~20 chars
sentence2 = "Another short sentence." # ~24 chars
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
# Function should now preserve punctuation
expected = ["Short sentence one.", "Another short sentence."]
assert list(yield_spans(text, max_tokens)) == expected
def test_yield_spans_words():
"""Test splitting by words when sentences exceed token limit"""
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
long_sentence = "This sentence has several words and needs word-level chunking."
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
def test_yield_spans_complex_document():
"""Test with a document containing multiple paragraphs and sentences"""
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
text = (
"Paragraph one with a short sentence. And another sentence that should be split.\n\n"
"Paragraph two is also here. It has multiple sentences. Some are short. "
"This one is longer and might need word splitting depending on the limit.\n\n"
"Final short paragraph."
)
assert list(yield_spans(text, max_tokens)) == [
"Paragraph one with a short sentence.",
"And another sentence that should be split.",
"Paragraph two is also here.",
"It has multiple sentences.",
"Some are short.",
"This one is longer and might need word",
"splitting depending on the limit.",
"Final short paragraph."
]
def test_yield_spans_very_long_word():
"""Test with a word that exceeds the token limit"""
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
assert list(yield_spans(long_word, max_tokens)) == [long_word]
def test_yield_spans_with_punctuation():
"""Test sentence splitting with various punctuation"""
text = "First sentence! Second sentence? Third sentence."
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
def test_yield_spans_edge_cases():
"""Test edge cases like empty paragraphs, single character paragraphs"""
text = "\n\nA\n\n\n\nB\n\n"
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
@pytest.mark.parametrize(
"text, expected",
[
("", []), # Empty text
("Short text", ["Short text"]), # Text below token limit
(" ", []), # Just whitespace
]
)
def test_chunk_text_basic_behavior(text, expected):
"""Test basic behavior of chunk_text with various inputs"""
assert list(chunk_text(text)) == expected
def test_chunk_text_single_paragraph():
"""Test chunking a single paragraph that fits within token limit"""
text = "This is a simple paragraph that should fit in one chunk."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_multi_paragraph():
"""Test chunking multiple paragraphs"""
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
assert list(chunk_text(text, max_tokens=20)) == [text]
def test_chunk_text_long_text():
"""Test chunking with long text that exceeds token limit"""
# Create a long text that will need multiple chunks
sentences = [f"This is sentence {i:02}." for i in range(50)]
text = " ".join(sentences)
max_tokens = 10 # 10 tokens = ~40 chars
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
] + [
'This is sentence 49.'
]
def test_chunk_text_with_overlap():
"""Test chunking with overlap between chunks"""
# Create text with distinct parts to test overlap
text = "Part A. Part B. Part C. Part D. Part E."
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
def test_chunk_text_zero_overlap():
"""Test chunking with zero overlap"""
text = "Part A. Part B. Part C. Part D. Part E."
# 2 tokens = ~8 chars
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
def test_chunk_text_clean_break():
"""Test that chunking attempts to break at sentence boundaries"""
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
max_tokens = 5 # Enough for about 2 sentences
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
def test_chunk_text_very_long_sentences():
"""Test with very long sentences that exceed the token limit"""
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
max_tokens = 5 # Small limit to force splitting
assert list(chunk_text(text, max_tokens=max_tokens)) == [
'This is a very long sentence with many many',
'words that will definitely exceed the',
'token limit we set for',
'this particular test',
'case and should be split into multiple',
'chunks by the function.',
]
@pytest.mark.parametrize(
"string, expected_count",
[
("", 0),
("a" * CHARS_PER_TOKEN, 1),
("a" * (CHARS_PER_TOKEN * 2), 2),
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
]
)
def test_approx_token_count(string, expected_count):
assert approx_token_count(string) == expected_count
@pytest.mark.parametrize(
"mime_type, expected_modality",
[
("text/plain", "doc"),
("text/html", "doc"),
("text/plain", "text"),
("text/html", "blog"),
("image/jpeg", "photo"),
("image/png", "photo"),
("application/pdf", "book"),
("application/pdf", "doc"),
("application/epub+zip", "book"),
("application/mobi", "book"),
("application/x-mobipocket-ebook", "book"),
("audio/mp3", "unknown"),
("video/mp4", "unknown"),
("text/something-new", "doc"), # Should match by 'text/' stem
("text/something-new", "text"), # Should match by 'text/' stem
("image/something-new", "photo"), # Should match by 'image/' stem
("custom/format", "unknown"), # No matching stem
]
],
)
def test_get_modality(mime_type, expected_modality):
assert get_modality(mime_type) == expected_modality
@ -329,13 +59,13 @@ def test_embed_text(mock_embed):
def test_embed_file(mock_embed, tmp_path):
mock_file = tmp_path / "test.txt"
mock_file.write_text("file content")
assert embed_file(mock_file) == [[0]]
def test_embed_mixed(mock_embed):
items = ["text", {"type": "image", "data": "base64"}]
assert embed_mixed(items) == [[0], [1]]
assert embed_mixed(items) == [[0]]
def test_embed_page_text_only(mock_embed):
@ -345,16 +75,159 @@ def test_embed_page_text_only(mock_embed):
def test_embed_page_mixed_content(mock_embed):
page = {"contents": ["text", {"type": "image", "data": "base64"}]}
assert embed_page(page) == [[0], [1]]
assert embed_page(page) == [[0]]
def test_embed(mock_embed):
mime_type = "text/plain"
content = "sample content"
metadata = {"source": "test"}
with patch.object(uuid, "uuid4", return_value="id1"):
modality, vectors = embed(mime_type, content, metadata)
assert modality == "doc"
assert vectors == [('id1', [0], {'source': 'test'})]
modality, chunks = embed(mime_type, content, metadata)
assert modality == "text"
assert [
{
"id": c.id,
"file_path": c.file_path,
"content": c.content,
"embedding_model": c.embedding_model,
"vector": c.vector,
"item_metadata": c.item_metadata,
}
for c in chunks
] == [
{
"content": "sample content",
"embedding_model": "voyage-3-large",
"file_path": None,
"id": "id1",
"item_metadata": {"source": "test"},
"vector": [0],
},
]
def test_write_to_file_text(mock_file_storage):
"""Test writing a string to a file."""
chunk_id = "test-chunk-id"
content = "This is a test string"
file_path = write_to_file(chunk_id, content)
assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.txt"
assert file_path.exists()
assert file_path.read_text() == content
def test_write_to_file_bytes(mock_file_storage):
"""Test writing bytes to a file."""
chunk_id = "test-chunk-id"
content = b"These are test bytes"
file_path = write_to_file(chunk_id, content)
assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.bin"
assert file_path.exists()
assert file_path.read_bytes() == content
def test_write_to_file_image(mock_file_storage):
"""Test writing an image to a file."""
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
chunk_id = "test-chunk-id"
file_path = write_to_file(chunk_id, img)
assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.png"
assert file_path.exists()
# Verify it's a valid image file by opening it
image = Image.open(file_path)
assert image.size == (100, 100)
def test_write_to_file_unsupported_type(mock_file_storage):
"""Test that an error is raised for unsupported content types."""
chunk_id = "test-chunk-id"
content = 123 # Integer is not a supported type
with pytest.raises(ValueError, match="Unsupported content type"):
write_to_file(chunk_id, content)
def test_make_chunk_text_only(mock_file_storage, db_session):
"""Test creating a chunk from string content."""
page = {
"contents": ["text content 1", "text content 2"],
"metadata": {"source": "test"},
}
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
):
chunk = make_chunk(page, vector, metadata)
assert chunk.id == "00000000-0000-0000-0000-000000000001"
assert chunk.content == "text content 1\n\ntext content 2"
assert chunk.file_path is None
assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata
def test_make_chunk_single_image(mock_file_storage, db_session):
"""Test creating a chunk from a single image."""
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
page = {"contents": [img], "metadata": {"source": "test"}}
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
):
chunk = make_chunk(page, vector, metadata)
assert chunk.id == "00000000-0000-0000-0000-000000000002"
assert chunk.content is None
assert chunk.file_path == (
settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
)
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata
# Verify the file exists
assert pathlib.Path(chunk.file_path[0]).exists()
def test_make_chunk_mixed_content(mock_file_storage, db_session):
"""Test creating a chunk from mixed content (string and image)."""
img = Image.new("RGB", (100, 100), color=(73, 109, 137))
page = {"contents": ["text content", img], "metadata": {"source": "test"}}
vector = [0.1, 0.2, 0.3]
metadata = {"doc_type": "test", "source": "unit-test"}
with patch.object(
uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
):
chunk = make_chunk(page, vector, metadata)
assert chunk.id == "00000000-0000-0000-0000-000000000003"
assert chunk.content is None
assert chunk.file_path == (
settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
)
assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
assert chunk.vector == vector
assert chunk.item_metadata == metadata
# Verify the files exist
assert (
settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_0.txt"
).exists()
assert (
settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
).exists()

View File

@ -70,7 +70,7 @@ def test_extract_image_with_path(tmp_path):
img.save(img_path)
page, = extract_image(img_path)
assert page["contents"].tobytes() == img.convert("RGB").tobytes()
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
assert page["metadata"] == {}
@ -81,7 +81,7 @@ def test_extract_image_with_bytes():
img_bytes = buffer.getvalue()
page, = extract_image(img_bytes)
assert page["contents"].tobytes() == img.convert("RGB").tobytes()
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
assert page["metadata"] == {}
@ -119,12 +119,11 @@ def test_extract_content_image(tmp_path):
img_path = tmp_path / "test_img.png"
img.save(img_path)
result = extract_content("image/png", img_path)
result, = extract_content("image/png", img_path)
assert len(result) == 1
assert isinstance(result[0]["contents"], Image.Image)
assert result[0]["contents"].size == (100, 100)
assert result[0]["metadata"] == {}
assert isinstance(result["contents"][0], Image.Image)
assert result["contents"][0].size == (100, 100)
assert result["metadata"] == {}
def test_extract_content_unsupported_type():

View File

@ -2,7 +2,12 @@ from unittest import mock
import pytest
from datetime import datetime, timedelta
from unittest.mock import patch
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common.db.models import (
EmailAccount,
MailMessage,
SourceItem,
EmailAttachment,
)
from memory.common import embedding
from memory.workers.tasks.email import process_message
@ -57,7 +62,7 @@ def test_email_account(db_session):
use_ssl=True,
folders=["INBOX", "Sent", "Archive"],
tags=["test", "integration"],
active=True
active=True,
)
db_session.add(account)
db_session.commit()
@ -66,62 +71,58 @@ def test_email_account(db_session):
def test_process_simple_email(db_session, test_email_account, qdrant):
"""Test processing a simple email message."""
source_id = process_message(
mail_message_id = process_message(
account_id=test_email_account.id,
message_id="101",
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
assert source_id is not None
# Check that the source item was created
source_item = db_session.query(SourceItem).filter(SourceItem.id == source_id).first()
assert source_item is not None
assert source_item.modality == "mail"
assert source_item.tags == test_email_account.tags
assert source_item.mime_type == "message/rfc822"
assert source_item.embed_status == "RAW"
# Check that the mail message was created and linked to the source
mail_message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()
mail_message = (
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
)
assert mail_message is not None
assert mail_message.modality == "mail"
assert mail_message.tags == test_email_account.tags
assert mail_message.mime_type == "message/rfc822"
assert mail_message.embed_status == "STORED"
assert mail_message.subject == "Test Email 1"
assert mail_message.sender == "alice@example.com"
assert "bob@example.com" in mail_message.recipients
assert "This is test email 1" in mail_message.body_raw
assert "This is test email 1" in mail_message.content
assert mail_message.folder == "INBOX"
def test_process_email_with_attachment(db_session, test_email_account, qdrant):
"""Test processing a message with an attachment."""
source_id = process_message(
mail_message_id = process_message(
account_id=test_email_account.id,
message_id="302",
folder="Archive",
raw_email=EMAIL_WITH_ATTACHMENT_RAW,
)
assert source_id is not None
# Check mail message specifics
mail_message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()
mail_message = (
db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
)
assert mail_message is not None
assert mail_message.subject == "Email with Attachment"
assert mail_message.sender == "eve@example.com"
assert "This email has an attachment" in mail_message.body_raw
assert "This email has an attachment" in mail_message.content
assert mail_message.folder == "Archive"
# Check attachments were processed and stored in the EmailAttachment table
attachments = db_session.query(EmailAttachment).filter(
EmailAttachment.mail_message_id == mail_message.id
).all()
assert len(attachments) > 0
assert attachments[0].filename == "test.txt"
assert attachments[0].content_type == "text/plain"
# Either content or file_path should be set
assert attachments[0].content is not None or attachments[0].file_path is not None
attachments = (
db_session.query(
EmailAttachment.filename,
EmailAttachment.content,
EmailAttachment.mime_type,
)
.filter(EmailAttachment.mail_message_id == mail_message.id)
.all()
)
assert attachments == [(None, "This is a test attachment", "text/plain")]
def test_process_empty_message(db_session, test_email_account, qdrant):
@ -132,7 +133,7 @@ def test_process_empty_message(db_session, test_email_account, qdrant):
folder="Archive",
raw_email="",
)
assert source_id is None
@ -145,13 +146,13 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
assert source_id_1 is not None, "First call should return a source_id"
# Count records to verify state before second call
source_count_before = db_session.query(SourceItem).count()
message_count_before = db_session.query(MailMessage).count()
# Second call with same email should detect duplicate and return None
source_id_2 = process_message(
account_id=test_email_account.id,
@ -159,12 +160,16 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
folder="INBOX",
raw_email=SIMPLE_EMAIL_RAW,
)
assert source_id_2 is None, "Second call should return None for duplicate message"
# Verify no new records were created
source_count_after = db_session.query(SourceItem).count()
message_count_after = db_session.query(MailMessage).count()
assert source_count_before == source_count_after, "No new SourceItem should be created"
assert message_count_before == message_count_after, "No new MailMessage should be created"
assert source_count_before == source_count_after, (
"No new SourceItem should be created"
)
assert message_count_before == message_count_after, (
"No new MailMessage should be created"
)
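# Rough sketch of the dedup flow these duplicate-message assertions rely on.
# This is an assumption about process_message's internals, not the real code;
# the helper names below do exist in memory.workers.email (see the imports in
# the next file of this commit).
from memory.workers.email import (
    check_message_exists,
    compute_message_hash,
    parse_email_message,
)

def _process_message_dedup_sketch(db_session, raw_email: str):
    parsed = parse_email_message(raw_email)
    msg_hash = compute_message_hash(
        parsed["message_id"], parsed["subject"], parsed["sender"], parsed["body"]
    )
    if check_message_exists(db_session, parsed["message_id"], msg_hash):
        return None  # duplicate: no new SourceItem / MailMessage rows are created
    return parsed  # the real code goes on to persist the MailMessage + attachments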

View File

@@ -9,19 +9,16 @@ from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY, MagicMock, patch
import pytest
from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount
from memory.common.db.models import (
SourceItem,
MailMessage,
EmailAttachment,
EmailAccount,
)
from memory.common import settings
from memory.common import embedding
from memory.workers.email import (
compute_message_hash,
create_source_item,
extract_attachments,
extract_body,
extract_date,
extract_email_uid,
extract_recipients,
parse_email_message,
check_message_exists,
create_mail_message,
fetch_email,
fetch_email_since,
@@ -35,6 +32,7 @@ from memory.workers.email import (
@pytest.fixture
def mock_uuid4():
i = 0
def uuid4():
nonlocal i
i += 1
@@ -44,181 +42,6 @@ def mock_uuid4():
yield
# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0
def _generate_test_message_id():
"""Generate a simple message ID for testing without expensive calls"""
global _msg_id_counter
_msg_id_counter += 1
return f"<test-message-{_msg_id_counter}@example.com>"
def create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
cc_addrs=None,
bcc_addrs=None,
date=None,
body="Test body content",
attachments=None,
multipart=True,
message_id=None,
):
"""Helper function to create email.message.Message objects for testing"""
if multipart:
msg = email.mime.multipart.MIMEMultipart()
msg.attach(email.mime.text.MIMEText(body))
if attachments:
for attachment in attachments:
attachment_part = email.mime.base.MIMEBase(
"application", "octet-stream"
)
attachment_part.set_payload(attachment["content"])
attachment_part.add_header(
"Content-Disposition",
f"attachment; filename={attachment['filename']}",
)
msg.attach(attachment_part)
else:
msg = email.mime.text.MIMEText(body)
msg["Subject"] = subject
msg["From"] = from_addr
msg["To"] = to_addrs
if cc_addrs:
msg["Cc"] = cc_addrs
if bcc_addrs:
msg["Bcc"] = bcc_addrs
if date:
msg["Date"] = formatdate(float(date.timestamp()))
if message_id:
msg["Message-ID"] = message_id
else:
msg["Message-ID"] = _generate_test_message_id()
return msg
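# Example usage of the helper above (values purely illustrative):
#
#   msg = create_email_message(
#       subject="Weekly report",
#       to_addrs="team@example.com",
#       attachments=[{"filename": "report.txt", "content": b"some bytes"}],
#   )
#   raw = msg.as_string()  # this is what parse_email_message() is fed in the tests below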
@pytest.mark.parametrize(
"to_addr, cc_addr, bcc_addr, expected",
[
# Single recipient in To field
("recipient@example.com", None, None, ["recipient@example.com"]),
# Multiple recipients in To field
(
"recipient1@example.com, recipient2@example.com",
None,
None,
["recipient1@example.com", "recipient2@example.com"],
),
# To, Cc fields
(
"recipient@example.com",
"cc@example.com",
None,
["recipient@example.com", "cc@example.com"],
),
# To, Cc, Bcc fields
(
"recipient@example.com",
"cc@example.com",
"bcc@example.com",
["recipient@example.com", "cc@example.com", "bcc@example.com"],
),
# Empty fields
("", "", "", []),
],
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
assert sorted(extract_recipients(msg)) == sorted(expected)
def test_extract_date_missing():
msg = create_email_message(date=None)
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Invalid Date Format",
"2023-01-01", # ISO format but not RFC compliant
"Monday, Jan 1, 2023", # Descriptive but not RFC compliant
"01/01/2023", # Common format but not RFC compliant
"", # Empty string
],
)
def test_extract_date_invalid_formats(date_str):
msg = create_email_message()
msg["Date"] = date_str
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format
"01 Jan 2023 12:00:00 +0000", # RFC 822 format
"Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name
],
)
def test_extract_date(date_str):
msg = create_email_message()
msg["Date"] = date_str
result = extract_date(msg)
assert result is not None
assert result.year == 2023
assert result.month == 1
assert result.day == 1
@pytest.mark.parametrize("multipart", [True, False])
def test_extract_body_text_plain(multipart):
body_content = "This is a test email body"
msg = create_email_message(body=body_content, multipart=multipart)
extracted = extract_body(msg)
# Strip newlines for comparison since multipart emails often add them
assert extracted.strip() == body_content.strip()
def test_extract_body_with_attachments():
body_content = "This is a test email body"
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(body=body_content, attachments=attachments)
assert body_content in extract_body(msg)
def test_extract_attachments_none():
msg = create_email_message(multipart=True)
assert extract_attachments(msg) == []
def test_extract_attachments_with_files():
attachments = [
{"filename": "test1.txt", "content": b"content1"},
{"filename": "test2.pdf", "content": b"content2"},
]
msg = create_email_message(attachments=attachments)
result = extract_attachments(msg)
assert len(result) == 2
assert result[0]["filename"] == "test1.txt"
assert result[1]["filename"] == "test2.pdf"
def test_extract_attachments_non_multipart():
msg = create_email_message(multipart=False)
assert extract_attachments(msg) == []
@pytest.mark.parametrize(
"attachment_size, max_inline_size, message_id",
[
@@ -237,7 +60,6 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id=message_id,
sender="sender@example.com",
folder="INBOX",
@@ -247,12 +69,8 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
result = process_attachment(attachment, message)
assert result is not None
# For inline attachments, content should be base64 encoded string
assert isinstance(result.content, bytes)
# Decode the base64 string and compare with the original content
decoded_content = base64.b64decode(result.content)
assert decoded_content == attachment["content"]
assert result.file_path is None
assert result.content == attachment["content"].decode("utf-8", errors="replace")
assert result.filename is None
@pytest.mark.parametrize(
@@ -266,14 +84,13 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
)
def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
attachment = {
"filename": "test.txt",
"filename": "test/with:special\\chars.txt",
"content_type": "text/plain",
"size": attachment_size,
"content": b"a" * attachment_size,
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id=message_id,
sender="sender@example.com",
folder="INBOX",
@@ -283,7 +100,12 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
assert result is not None
assert not result.content
assert result.file_path == str(settings.FILE_STORAGE_DIR / "sender@example.com" / "INBOX" / "test.txt")
assert result.filename == str(
settings.FILE_STORAGE_DIR
/ "sender_example_com"
/ "INBOX"
/ "test_with_special_chars.txt"
)
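# The expected path above implies both the sender address and the attachment
# filename get sanitised before anything is written to disk. A minimal sketch of
# that behaviour (hypothetical helpers, not the actual implementation):
import re

def _clean(part: str) -> str:
    return re.sub(r"\W", "_", part)  # anything that's not [A-Za-z0-9_] becomes "_"

def _safe_filename(name: str) -> str:
    stem, dot, suffix = name.rpartition(".")
    return _clean(stem) + "." + _clean(suffix) if dot else _clean(name)

# _clean("sender@example.com")                    -> "sender_example_com"
# _safe_filename("test/with:special\\chars.txt")  -> "test_with_special_chars.txt"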
def test_process_attachment_write_error():
@@ -296,7 +118,6 @@ def test_process_attachment_write_error():
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id="<test@example.com>",
sender="sender@example.com",
folder="INBOX",
@@ -344,7 +165,7 @@ def test_process_attachments_mixed():
]
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
tags=["test"],
message_id="<test@example.com>",
sender="sender@example.com",
folder="INBOX",
@@ -357,84 +178,14 @@ def test_process_attachments_mixed():
# Verify we have all attachments processed
assert len(results) == 3
# Verify small attachments are base64 encoded
assert isinstance(results[0].content, bytes)
assert isinstance(results[2].content, bytes)
assert results[0].content == "a" * 20
assert results[2].content == "c" * 30
# Verify large attachment has a path
assert results[1].file_path is not None
@pytest.mark.parametrize(
"msg_id, subject, sender, body, expected",
[
(
"<test@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xf2\xbd", # First two bytes of the actual hash
),
(
"<different@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xa4\x15", # Will be different from the first hash
),
],
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
result = compute_message_hash(msg_id, subject, sender, body)
# Verify it's bytes and correct length for SHA-256 (32 bytes)
assert isinstance(result, bytes)
assert len(result) == 32
# Verify first two bytes match expected
assert result[:2] == expected
def test_hash_consistency():
args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
assert compute_message_hash(*args) == compute_message_hash(*args)
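# compute_message_hash is treated here as an opaque, deterministic 32-byte digest;
# conceptually it is something like the sketch below. The real implementation may
# use a different separator or encoding, so this sketch will not reproduce the
# hard-coded b"\xf2\xbd" prefix asserted above.
import hashlib

def _message_hash_sketch(msg_id: str, subject: str, sender: str, body: str) -> bytes:
    return hashlib.sha256("".join([msg_id, subject, sender, body]).encode()).digest()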
def test_parse_simple_email():
test_date = datetime(2023, 1, 1, 12, 0, 0)
msg_id = "<test123@example.com>"
msg = create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
date=test_date,
body="Test body content",
message_id=msg_id,
assert results[1].filename == str(
settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
)
result = parse_email_message(msg.as_string())
assert result == {
"message_id": msg_id,
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient@example.com"],
"body": "Test body content\n",
"attachments": [],
"sent_at": ANY,
}
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
def test_parse_email_with_attachments():
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(attachments=attachments)
result = parse_email_message(msg.as_string())
assert len(result["attachments"]) == 1
assert result["attachments"][0]["filename"] == "test.txt"
def test_extract_email_uid_valid():
msg_data = [(b"1 (UID 12345 RFC822 {1234}", b"raw email content")]
@@ -452,148 +203,55 @@ def test_extract_email_uid_no_match():
assert raw_email == b"raw email content"
def test_create_source_item(db_session):
# Mock data
message_hash = b"test_hash_bytes" + bytes(28) # 32 bytes for SHA-256
account_tags = ["work", "important"]
raw_email_size = 1024
# Call function
source_item = create_source_item(
db_session=db_session,
message_hash=message_hash,
account_tags=account_tags,
raw_size=raw_email_size,
)
# Verify the source item was created correctly
assert isinstance(source_item, SourceItem)
assert source_item.id is not None
assert source_item.modality == "mail"
assert source_item.sha256 == message_hash
assert source_item.tags == account_tags
assert source_item.byte_length == raw_email_size
assert source_item.mime_type == "message/rfc822"
assert source_item.embed_status == "RAW"
# Verify it was added to the session
db_session.flush()
fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
assert fetched_item is not None
assert fetched_item.sha256 == message_hash
@pytest.mark.parametrize(
"setup_db, message_id, message_hash, expected_exists",
[
# Test by message ID
(
lambda db: (
# First create source_item to satisfy foreign key constraint
db.add(
SourceItem(
id=1,
modality="mail",
sha256=b"some_hash_bytes" + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
),
db.flush(),
# Then create mail_message
db.add(
MailMessage(
source_id=1,
message_id="<test@example.com>",
subject="Test",
sender="test@example.com",
recipients=["recipient@example.com"],
body_raw="Test body",
)
),
),
"<test@example.com>",
b"unmatched_hash",
True,
),
# Test by non-existent message ID
(lambda db: None, "<nonexistent@example.com>", b"unmatched_hash", False),
# Test by hash
(
lambda db: db.add(
SourceItem(
modality="mail",
sha256=b"test_hash_bytes" + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
),
"",
b"test_hash_bytes" + bytes(28),
True,
),
# Test by non-existent hash
(lambda db: None, "", b"different_hash_" + bytes(28), False),
],
)
def test_check_message_exists(
db_session, setup_db, message_id, message_hash, expected_exists
):
# Setup test data
if setup_db:
setup_db(db_session)
db_session.flush()
# Test the function
assert check_message_exists(db_session, message_id, message_hash) == expected_exists
def test_create_mail_message(db_session):
source_item = SourceItem(
modality="mail",
sha256=b"test_hash_bytes" + bytes(28),
tags=["test"],
byte_length=100,
raw_email = (
"From: sender@example.com\n"
"To: recipient@example.com\n"
"Subject: Test Subject\n"
"Date: Sun, 1 Jan 2023 12:00:00 +0000\n"
"Message-ID: 321\n"
"MIME-Version: 1.0\n"
'Content-Type: multipart/mixed; boundary="boundary"\n'
"\n"
"--boundary\n"
"Content-Type: text/plain\n"
"\n"
"Test body content\n"
"--boundary\n"
'Content-Disposition: attachment; filename="test.txt"\n'
"Content-Type: text/plain\n"
"Content-Transfer-Encoding: base64\n"
"\n"
"YXR0YWNobWVudCBjb250ZW50\n"
"--boundary--"
)
db_session.add(source_item)
db_session.flush()
parsed_email = {
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient@example.com"],
"sent_at": datetime(2023, 1, 1, 12, 0, 0),
"body": "Test body content",
"attachments": [
{"filename": "test.txt", "content_type": "text/plain", "size": 100}
],
}
folder = "INBOX"
# Call function
mail_message = create_mail_message(
db_session=db_session,
source_item=source_item,
parsed_email=parsed_email,
raw_email=raw_email,
folder=folder,
tags=["test"],
message_id="123",
)
db_session.commit()
attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all()
attachments = (
db_session.query(EmailAttachment)
.filter(EmailAttachment.mail_message_id == mail_message.id)
.all()
)
# Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage)
assert mail_message.source_id == source_item.id
assert mail_message.message_id == parsed_email["message_id"]
assert mail_message.subject == parsed_email["subject"]
assert mail_message.sender == parsed_email["sender"]
assert mail_message.recipients == parsed_email["recipients"]
assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat()
assert mail_message.body_raw == parsed_email["body"]
assert mail_message.message_id == "321"
assert mail_message.subject == "Test Subject"
assert mail_message.sender == "sender@example.com"
assert mail_message.recipients == ["recipient@example.com"]
assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
assert mail_message.content == raw_email
assert mail_message.body == "Test body content\n"
assert mail_message.attachments == attachments
@@ -675,99 +333,97 @@ def test_process_folder_error(email_provider):
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
mail_message = MailMessage(
source=SourceItem(
modality="mail",
sha256=b"test_hash" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
),
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="RAW",
message_id="<test-vector@example.com>",
subject="Test Vectorization",
sender="sender@example.com",
recipients=["recipient@example.com"],
body_raw="This is a test email for vectorization",
content="This is a test email for vectorization",
folder="INBOX",
)
db_session.add(mail_message)
db_session.flush()
assert mail_message.embed_status == "RAW"
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
vectorize_email(mail_message)
assert [c.id for c in mail_message.chunks] == [
"00000000-0000-0000-0000-000000000001"
]
db_session.commit()
assert mail_message.embed_status == "STORED"
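# Under the chunk model exercised here, vectorize_email no longer returns raw
# vector ids; the ids live on Chunk rows attached to the source item, e.g.:
#
#   vectorize_email(mail_message)
#   ids = [chunk.id for chunk in mail_message.chunks]  # plain UUIDs, no "mail/" prefix
#
# (previously the call itself returned ids of the form "mail/<uuid>")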
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
mail_message = MailMessage(
source=SourceItem(
modality="mail",
sha256=b"test_hash" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
),
sha256=b"test_hash" + bytes(24),
tags=["test"],
size=100,
mime_type="message/rfc822",
embed_status="RAW",
message_id="<test-vector-attach@example.com>",
subject="Test Vectorization with Attachments",
sender="sender@example.com",
recipients=["recipient@example.com"],
body_raw="This is a test email with attachments",
content="This is a test email with attachments",
folder="INBOX",
)
db_session.add(mail_message)
db_session.flush()
# Add two attachments - one with content and one with file_path
attachment1 = EmailAttachment(
mail_message_id=mail_message.id,
filename="inline.txt",
content_type="text/plain",
size=100,
content=base64.b64encode(b"This is inline content"),
file_path=None,
source=SourceItem(
modality="doc",
sha256=b"test_hash1" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
filename=None,
modality="doc",
sha256=b"test_hash1" + bytes(24),
tags=["test"],
mime_type="text/plain",
embed_status="RAW",
)
file_path = mail_message.attachments_path / "stored.txt"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(b"This is stored content")
attachment2 = EmailAttachment(
mail_message_id=mail_message.id,
filename="stored.txt",
content_type="text/plain",
size=200,
content=None,
file_path=str(file_path),
source=SourceItem(
modality="doc",
sha256=b"test_hash2" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
filename=str(file_path),
modality="doc",
sha256=b"test_hash2" + bytes(24),
tags=["test"],
mime_type="text/plain",
embed_status="RAW",
)
db_session.add_all([attachment1, attachment2])
db_session.flush()
# Mock embedding functions but use real qdrant
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
# Call the function
vector_ids = vectorize_email(mail_message)
vectorize_email(mail_message)
# Verify results
assert len(vector_ids) == 3
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
assert vector_ids[1] == "doc/00000000-0000-0000-0000-000000000002"
assert vector_ids[2] == "doc/00000000-0000-0000-0000-000000000003"
vector_ids = [
c.id for c in mail_message.chunks + attachment1.chunks + attachment2.chunks
]
assert vector_ids == [
"00000000-0000-0000-0000-000000000001",
"00000000-0000-0000-0000-000000000002",
"00000000-0000-0000-0000-000000000003",
]
db_session.commit()
assert mail_message.embed_status == "STORED"
assert attachment1.embed_status == "STORED"
assert attachment2.embed_status == "STORED"