Mirror of https://github.com/mruwnik/memory.git
Synced: 2025-07-30 06:36:07 +02:00

Commit 44de394eb1 (parent 453aed7c19): use proper chunk objects
@ -1,8 +1,8 @@
"""Initial structure
"""Initial structure for the database.

Revision ID: a466a07360d5
Revises:
Create Date: 2025-04-27 17:15:37.487616
Revision ID: 4684845ca51e
Revises: a466a07360d5
Create Date: 2025-05-03 14:00:56.113840

"""

@ -13,7 +13,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "a466a07360d5"
revision: str = "4684845ca51e"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -21,12 +21,7 @@ depends_on: Union[str, Sequence[str], None] = None

def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")

# Create enum type for github_item with IF NOT EXISTS
op.execute(
"DO $$ BEGIN CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); EXCEPTION WHEN duplicate_object THEN NULL; END $$;"
)

op.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
op.create_table(
"email_accounts",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -56,6 +51,22 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email_address"),
|
||||
)
|
||||
op.create_index(
|
||||
"email_accounts_active_idx",
|
||||
"email_accounts",
|
||||
["active", "last_sync_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"email_accounts_address_idx", "email_accounts", ["email_address"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
"email_accounts_tags_idx",
|
||||
"email_accounts",
|
||||
["tags"],
|
||||
unique=False,
|
||||
postgresql_using="gin",
|
||||
)
|
||||
op.create_table(
|
||||
"rss_feeds",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
@ -80,6 +91,16 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("url"),
|
||||
)
|
||||
op.create_index(
|
||||
"rss_feeds_active_idx", "rss_feeds", ["active", "last_checked_at"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"rss_feeds_tags_idx",
|
||||
"rss_feeds",
|
||||
["tags"],
|
||||
unique=False,
|
||||
postgresql_using="gin",
|
||||
)
|
||||
op.create_table(
|
||||
"source_item",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
@ -92,134 +113,103 @@ def upgrade() -> None:
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
|
||||
sa.Column("lang", sa.Text(), nullable=True),
|
||||
sa.Column("model_hash", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"vector_ids", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
|
||||
),
|
||||
sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
|
||||
sa.Column("byte_length", sa.Integer(), nullable=True),
|
||||
sa.Column("size", sa.Integer(), nullable=True),
|
||||
sa.Column("mime_type", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
|
||||
sa.Column("type", sa.String(length=50), nullable=True),
|
||||
sa.CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("sha256"),
|
||||
)
|
||||
op.create_index("source_filename_idx", "source_item", ["filename"], unique=False)
|
||||
op.create_index("source_modality_idx", "source_item", ["modality"], unique=False)
|
||||
op.create_index("source_status_idx", "source_item", ["embed_status"], unique=False)
|
||||
op.create_index(
|
||||
"source_tags_idx", "source_item", ["tags"], unique=False, postgresql_using="gin"
|
||||
)
|
||||
op.create_table(
|
||||
"blog_post",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("url", sa.Text(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("url"),
|
||||
)
|
||||
op.create_table(
|
||||
"book_doc",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("author", sa.Text(), nullable=True),
|
||||
sa.Column("chapter", sa.Text(), nullable=True),
|
||||
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"chat_message",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("platform", sa.Text(), nullable=True),
|
||||
sa.Column("channel_id", sa.Text(), nullable=True),
|
||||
sa.Column("author", sa.Text(), nullable=True),
|
||||
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("body_raw", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"git_commit",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("repo_path", sa.Text(), nullable=True),
|
||||
sa.Column("commit_sha", sa.Text(), nullable=True),
|
||||
sa.Column("author_name", sa.Text(), nullable=True),
|
||||
sa.Column("author_email", sa.Text(), nullable=True),
|
||||
sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("msg_raw", sa.Text(), nullable=True),
|
||||
sa.Column("diff_summary", sa.Text(), nullable=True),
|
||||
sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("commit_sha"),
|
||||
op.create_index(
|
||||
"chat_channel_idx", "chat_message", ["platform", "channel_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"mail_message",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
"chunk",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.UUID(),
|
||||
server_default=sa.text("uuid_generate_v4()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.Text(), nullable=True),
|
||||
sa.Column("subject", sa.Text(), nullable=True),
|
||||
sa.Column("sender", sa.Text(), nullable=True),
|
||||
sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("body_raw", sa.Text(), nullable=True),
|
||||
sa.Column("folder", sa.Text(), nullable=True),
|
||||
sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("message_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"email_attachment",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("filename", sa.Text(), nullable=False),
|
||||
sa.Column("content_type", sa.Text(), nullable=True),
|
||||
sa.Column("size", sa.Integer(), nullable=True),
|
||||
sa.Column("content", postgresql.BYTEA(), nullable=True),
|
||||
sa.Column("file_path", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("embedding_model", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_id"], ["source_item.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"misc_doc",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("path", sa.Text(), nullable=True),
|
||||
sa.Column("mime_type", sa.Text(), nullable=True),
|
||||
sa.CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("chunk_source_idx", "chunk", ["source_id"], unique=False)
|
||||
op.create_table(
|
||||
"photo",
|
||||
"git_commit",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=True),
|
||||
sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("exif_lat", sa.Numeric(9, 6), nullable=True),
|
||||
sa.Column("exif_lon", sa.Numeric(9, 6), nullable=True),
|
||||
sa.Column("camera", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.Column("repo_path", sa.Text(), nullable=True),
|
||||
sa.Column("commit_sha", sa.Text(), nullable=True),
|
||||
sa.Column("author_name", sa.Text(), nullable=True),
|
||||
sa.Column("author_email", sa.Text(), nullable=True),
|
||||
sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("diff_summary", sa.Text(), nullable=True),
|
||||
sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("commit_sha"),
|
||||
)
|
||||
op.create_index("git_date_idx", "git_commit", ["author_date"], unique=False)
|
||||
op.create_index(
|
||||
"git_files_idx",
|
||||
"git_commit",
|
||||
["files_changed"],
|
||||
unique=False,
|
||||
postgresql_using="gin",
|
||||
)
|
||||
|
||||
# Add github_item table
|
||||
op.create_table(
|
||||
"github_item",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("kind", sa.Text(), nullable=False),
|
||||
sa.Column("repo_path", sa.Text(), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=True),
|
||||
@ -227,7 +217,6 @@ def upgrade() -> None:
|
||||
sa.Column("commit_sha", sa.Text(), nullable=True),
|
||||
sa.Column("state", sa.Text(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("body_raw", sa.Text(), nullable=True),
|
||||
sa.Column("labels", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("author", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
@ -235,147 +224,125 @@ def upgrade() -> None:
|
||||
sa.Column("merged_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("diff_summary", sa.Text(), nullable=True),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Add constraint to github_item.kind
|
||||
op.create_check_constraint(
|
||||
"github_item_kind_check",
|
||||
op.create_index(
|
||||
"gh_issue_lookup_idx",
|
||||
"github_item",
|
||||
"kind IN ('issue', 'pr', 'comment', 'project_card')",
|
||||
)
|
||||
|
||||
# Add missing constraint to source_item
|
||||
op.create_check_constraint(
|
||||
"source_item_embed_status_check",
|
||||
"source_item",
|
||||
"embed_status IN ('RAW','QUEUED','STORED','FAILED')",
|
||||
)
|
||||
|
||||
# Create trigger function for vector_ids validation
|
||||
op.execute("""
|
||||
CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
|
||||
RETURNS TRIGGER LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.embed_status = 'STORED'
|
||||
AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
|
||||
RAISE EXCEPTION
|
||||
USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
""")
|
||||
|
||||
# Create trigger
|
||||
op.execute("""
|
||||
CREATE TRIGGER check_vector_ids
|
||||
BEFORE UPDATE ON source_item
|
||||
FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();
|
||||
""")
|
||||
|
||||
# Create indexes for source_item
|
||||
op.create_index("source_modality_idx", "source_item", ["modality"])
|
||||
op.create_index("source_status_idx", "source_item", ["embed_status"])
|
||||
op.create_index("source_tags_idx", "source_item", ["tags"], postgresql_using="gin")
|
||||
|
||||
# Create indexes for mail_message
|
||||
op.create_index("mail_sent_idx", "mail_message", ["sent_at"])
|
||||
op.create_index(
|
||||
"mail_recipients_idx", "mail_message", ["recipients"], postgresql_using="gin"
|
||||
)
|
||||
op.create_index("email_attachment_filename_idx", "email_attachment", ["filename"], unique=False)
|
||||
op.create_index("email_attachment_message_idx", "email_attachment", ["mail_message_id"], unique=False)
|
||||
op.create_index("mail_tsv_idx", "mail_message", ["tsv"], postgresql_using="gin")
|
||||
|
||||
# Create index for chat_message
|
||||
op.create_index("chat_channel_idx", "chat_message", ["platform", "channel_id"])
|
||||
|
||||
# Create indexes for git_commit
|
||||
op.create_index(
|
||||
"git_files_idx", "git_commit", ["files_changed"], postgresql_using="gin"
|
||||
)
|
||||
op.create_index("git_date_idx", "git_commit", ["author_date"])
|
||||
|
||||
# Create index for photo
|
||||
op.create_index("photo_taken_idx", "photo", ["exif_taken_at"])
|
||||
|
||||
# Create indexes for rss_feeds
|
||||
op.create_index("rss_feeds_active_idx", "rss_feeds", ["active", "last_checked_at"])
|
||||
op.create_index("rss_feeds_tags_idx", "rss_feeds", ["tags"], postgresql_using="gin")
|
||||
|
||||
# Create indexes for email_accounts
|
||||
op.create_index(
|
||||
"email_accounts_address_idx", "email_accounts", ["email_address"], unique=True
|
||||
["repo_path", "kind", "number"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"email_accounts_active_idx", "email_accounts", ["active", "last_sync_at"]
|
||||
"gh_labels_idx", "github_item", ["labels"], unique=False, postgresql_using="gin"
|
||||
)
|
||||
op.create_index(
|
||||
"email_accounts_tags_idx", "email_accounts", ["tags"], postgresql_using="gin"
|
||||
"gh_repo_kind_idx", "github_item", ["repo_path", "kind"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"mail_message",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.Text(), nullable=True),
|
||||
sa.Column("subject", sa.Text(), nullable=True),
|
||||
sa.Column("sender", sa.Text(), nullable=True),
|
||||
sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("folder", sa.Text(), nullable=True),
|
||||
sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("message_id"),
|
||||
)
|
||||
|
||||
# Create indexes for github_item
|
||||
op.create_index("gh_repo_kind_idx", "github_item", ["repo_path", "kind"])
|
||||
op.create_index(
|
||||
"gh_issue_lookup_idx", "github_item", ["repo_path", "kind", "number"]
|
||||
"mail_recipients_idx",
|
||||
"mail_message",
|
||||
["recipients"],
|
||||
unique=False,
|
||||
postgresql_using="gin",
|
||||
)
|
||||
op.create_index("mail_sent_idx", "mail_message", ["sent_at"], unique=False)
|
||||
op.create_index(
|
||||
"mail_tsv_idx", "mail_message", ["tsv"], unique=False, postgresql_using="gin"
|
||||
)
|
||||
op.create_table(
|
||||
"misc_doc",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("path", sa.Text(), nullable=True),
|
||||
sa.Column("mime_type", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"photo",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("exif_lat", sa.Numeric(precision=9, scale=6), nullable=True),
|
||||
sa.Column("exif_lon", sa.Numeric(precision=9, scale=6), nullable=True),
|
||||
sa.Column("camera", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("photo_taken_idx", "photo", ["exif_taken_at"], unique=False)
|
||||
op.create_table(
|
||||
"email_attachment",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"email_attachment_message_idx",
|
||||
"email_attachment",
|
||||
["mail_message_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index("gh_labels_idx", "github_item", ["labels"], postgresql_using="gin")
|
||||
|
||||
# Create add_tags helper function
|
||||
op.execute("""
|
||||
CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
|
||||
RETURNS VOID LANGUAGE SQL AS $$
|
||||
UPDATE source_item
|
||||
SET tags =
|
||||
(SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
|
||||
WHERE id = p_source;
|
||||
$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index("gh_labels_idx", table_name="github_item")
|
||||
op.drop_index("gh_issue_lookup_idx", table_name="github_item")
|
||||
op.drop_index("gh_repo_kind_idx", table_name="github_item")
|
||||
op.drop_index("email_accounts_tags_idx", table_name="email_accounts")
|
||||
op.drop_index("email_accounts_active_idx", table_name="email_accounts")
|
||||
op.drop_index("email_accounts_address_idx", table_name="email_accounts")
|
||||
op.drop_index("rss_feeds_tags_idx", table_name="rss_feeds")
|
||||
op.drop_index("rss_feeds_active_idx", table_name="rss_feeds")
|
||||
op.drop_index("photo_taken_idx", table_name="photo")
|
||||
op.drop_index("git_date_idx", table_name="git_commit")
|
||||
op.drop_index("git_files_idx", table_name="git_commit")
|
||||
op.drop_index("chat_channel_idx", table_name="chat_message")
|
||||
op.drop_index("mail_tsv_idx", table_name="mail_message")
|
||||
op.drop_index("mail_recipients_idx", table_name="mail_message")
|
||||
op.drop_index("mail_sent_idx", table_name="mail_message")
|
||||
op.drop_index("email_attachment_message_idx", table_name="email_attachment")
|
||||
op.drop_index("email_attachment_filename_idx", table_name="email_attachment")
|
||||
op.drop_index("source_tags_idx", table_name="source_item")
|
||||
op.drop_index("source_status_idx", table_name="source_item")
|
||||
op.drop_index("source_modality_idx", table_name="source_item")
|
||||
|
||||
# Drop tables
|
||||
op.drop_table("email_attachment")
|
||||
op.drop_index("photo_taken_idx", table_name="photo")
|
||||
op.drop_table("photo")
|
||||
op.drop_table("misc_doc")
|
||||
op.drop_index("mail_tsv_idx", table_name="mail_message", postgresql_using="gin")
|
||||
op.drop_index("mail_sent_idx", table_name="mail_message")
|
||||
op.drop_index(
|
||||
"mail_recipients_idx", table_name="mail_message", postgresql_using="gin"
|
||||
)
|
||||
op.drop_table("mail_message")
|
||||
op.drop_index("gh_repo_kind_idx", table_name="github_item")
|
||||
op.drop_index("gh_labels_idx", table_name="github_item", postgresql_using="gin")
|
||||
op.drop_index("gh_issue_lookup_idx", table_name="github_item")
|
||||
op.drop_table("github_item")
|
||||
op.drop_index("git_files_idx", table_name="git_commit", postgresql_using="gin")
|
||||
op.drop_index("git_date_idx", table_name="git_commit")
|
||||
op.drop_table("git_commit")
|
||||
op.drop_index("chunk_source_idx", table_name="chunk")
|
||||
op.drop_table("chunk")
|
||||
op.drop_index("chat_channel_idx", table_name="chat_message")
|
||||
op.drop_table("chat_message")
|
||||
op.drop_table("book_doc")
|
||||
op.drop_table("blog_post")
|
||||
op.drop_table("github_item")
|
||||
op.drop_table("email_attachment")
|
||||
op.drop_table("mail_message")
|
||||
op.drop_index("source_tags_idx", table_name="source_item", postgresql_using="gin")
|
||||
op.drop_index("source_status_idx", table_name="source_item")
|
||||
op.drop_index("source_modality_idx", table_name="source_item")
|
||||
op.drop_table("source_item")
|
||||
op.drop_index("rss_feeds_tags_idx", table_name="rss_feeds", postgresql_using="gin")
|
||||
op.drop_index("rss_feeds_active_idx", table_name="rss_feeds")
|
||||
op.drop_table("rss_feeds")
|
||||
op.drop_index(
|
||||
"email_accounts_tags_idx", table_name="email_accounts", postgresql_using="gin"
|
||||
)
|
||||
op.drop_index("email_accounts_address_idx", table_name="email_accounts")
|
||||
op.drop_index("email_accounts_active_idx", table_name="email_accounts")
|
||||
op.drop_table("email_accounts")
|
||||
|
||||
# Drop triggers and functions
|
||||
op.execute("DROP TRIGGER IF EXISTS check_vector_ids ON source_item")
|
||||
op.execute("DROP FUNCTION IF EXISTS trg_vector_ids_not_empty()")
|
||||
op.execute("DROP FUNCTION IF EXISTS add_tags(BIGINT, TEXT[])")
|
||||
|
||||
# Drop enum type
|
||||
op.execute("DROP TYPE IF EXISTS gh_item_kind")
|
src/memory/common/chunker.py (new file, 130 lines)
@ -0,0 +1,130 @@
import logging
|
||||
from typing import Iterable, Any
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Chunking configuration
|
||||
MAX_TOKENS = 32000 # VoyageAI max context window
|
||||
OVERLAP_TOKENS = 200 # Default overlap between chunks
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
Vector = list[float]
|
||||
Embedding = tuple[str, Vector, dict[str, Any]]
|
||||
|
||||
|
||||
# Regex for sentence splitting
|
||||
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
|
||||
|
||||
|
||||
def approx_token_count(s: str) -> int:
|
||||
return len(s) // CHARS_PER_TOKEN
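The four-characters-per-token heuristic is deliberately rough; for example:

    assert approx_token_count("hello world!") == 3  # 12 characters // 4
    assert approx_token_count("abc") == 0           # very short strings round down to zero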
|
||||
|
||||
|
||||
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
words = text.split()
|
||||
if not words:
|
||||
return
|
||||
|
||||
current = ""
|
||||
for word in words:
|
||||
new_chunk = f"{current} {word}".strip()
|
||||
if current and approx_token_count(new_chunk) > max_tokens:
|
||||
yield current
|
||||
current = word
|
||||
else:
|
||||
current = new_chunk
|
||||
if current: # Only yield non-empty final chunk
|
||||
yield current
|
||||
|
||||
|
||||
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Yield text spans in priority order: paragraphs, sentences, words.
|
||||
Each span is guaranteed to be under max_tokens.
|
||||
|
||||
Args:
|
||||
text: The text to split
|
||||
max_tokens: Maximum tokens per chunk
|
||||
|
||||
Yields:
|
||||
Spans of text that fit within token limits
|
||||
"""
|
||||
# Return early for empty text
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
for paragraph in text.split("\n\n"):
|
||||
if not paragraph.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(paragraph) <= max_tokens:
|
||||
yield paragraph
|
||||
continue
|
||||
|
||||
for sentence in _SENT_SPLIT_RE.split(paragraph):
|
||||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(sentence) <= max_tokens:
|
||||
yield sentence
|
||||
continue
|
||||
|
||||
for chunk in yield_word_chunks(sentence, max_tokens):
|
||||
yield chunk
|
||||
|
||||
|
||||
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Split text into chunks respecting semantic boundaries while staying within token limits.
|
||||
|
||||
Args:
|
||||
text: The text to chunk
|
||||
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
|
||||
overlap: Number of tokens to overlap between chunks (default: 200)
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
if approx_token_count(text) <= max_tokens:
|
||||
yield text
|
||||
return
|
||||
|
||||
overlap_chars = overlap * CHARS_PER_TOKEN
|
||||
current = ""
|
||||
|
||||
for span in yield_spans(text, max_tokens):
|
||||
current = f"{current} {span}".strip()
|
||||
if approx_token_count(current) < max_tokens:
|
||||
continue
|
||||
|
||||
if overlap <= 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
overlap_text = current[-overlap_chars:]
|
||||
clean_break = max(
|
||||
overlap_text.rfind(". "),
|
||||
overlap_text.rfind("! "),
|
||||
overlap_text.rfind("? ")
|
||||
)
|
||||
|
||||
if clean_break < 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
break_offset = -overlap_chars + clean_break + 1
|
||||
chunk = current[break_offset:].strip()
|
||||
yield current
|
||||
current = chunk
|
||||
|
||||
if current:
|
||||
yield current.strip()
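A quick, hedged sketch of driving the chunker end to end (the file name and parameter values are illustrative, not taken from the repo):

    # Hypothetical usage of chunk_text on a local text file.
    import pathlib
    from memory.common.chunker import approx_token_count, chunk_text

    text = pathlib.Path("notes.txt").read_text()
    for i, chunk in enumerate(chunk_text(text, max_tokens=512, overlap=50)):
        print(f"chunk {i}: ~{approx_token_count(chunk)} tokens")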
|
@ -1,257 +1,351 @@
|
||||
"""
|
||||
Database models for the knowledge base system.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
from email.message import EmailMessage
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean,
|
||||
ARRAY, func, Numeric, CheckConstraint, Index
|
||||
ARRAY,
|
||||
UUID,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from memory.common import settings
|
||||
from memory.common.parsers.email import parse_email_message
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
__tablename__ = "chunk"
|
||||
|
||||
# The ID is also used as the vector ID in the vector database
|
||||
id = Column(
|
||||
UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4()
|
||||
)
|
||||
source_id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_path = Column(Text) # Path to content if stored as a file
|
||||
content = Column(Text) # Direct content storage
|
||||
embedding_model = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector = list[float] | None # the vector generated by the embedding model
|
||||
item_metadata = dict[str, Any] | None
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
|
||||
Index("chunk_source_idx", "source_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self) -> list[bytes | str | Image.Image]:
|
||||
if not self.file_path:
|
||||
return [self.content]
|
||||
|
||||
path = pathlib.Path(self.file_path)
|
||||
if self.file_path.endswith("*"):
|
||||
files = list(path.parent.glob(path.name))
|
||||
else:
|
||||
files = [path]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
if file_path.suffix == ".png":
|
||||
items.append(Image.open(file_path))
|
||||
elif file_path.suffix == ".bin":
|
||||
items.append(file_path.read_bytes())
|
||||
else:
|
||||
items.append(file_path.read_text())
|
||||
return items
|
||||
|
||||
|
||||
class SourceItem(Base):
|
||||
__tablename__ = 'source_item'
|
||||
|
||||
"""Base class for all content in the system using SQLAlchemy's joined table inheritance."""
|
||||
|
||||
__tablename__ = "source_item"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
modality = Column(Text, nullable=False)
|
||||
sha256 = Column(BYTEA, nullable=False, unique=True)
|
||||
inserted_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
|
||||
lang = Column(Text)
|
||||
model_hash = Column(Text)
|
||||
vector_ids = Column(ARRAY(Text), nullable=False, server_default='{}')
|
||||
embed_status = Column(Text, nullable=False, server_default='RAW')
|
||||
byte_length = Column(Integer)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
size = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
|
||||
attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False)
|
||||
|
||||
# Content is stored in the database if it's small enough and text
|
||||
content = Column(Text)
|
||||
# Otherwise the content is stored on disk
|
||||
filename = Column(Text, nullable=True)
|
||||
|
||||
# Chunks relationship
|
||||
embed_status = Column(Text, nullable=False, server_default="RAW")
|
||||
chunks = relationship("Chunk", backref="source", cascade="all, delete-orphan")
|
||||
|
||||
# Discriminator column for SQLAlchemy inheritance
|
||||
type = Column(String(50))
|
||||
|
||||
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "source_item"}
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
Index('source_modality_idx', 'modality'),
|
||||
Index('source_status_idx', 'embed_status'),
|
||||
Index('source_tags_idx', 'tags', postgresql_using='gin'),
|
||||
Index("source_modality_idx", "modality"),
|
||||
Index("source_status_idx", "embed_status"),
|
||||
Index("source_tags_idx", "tags", postgresql_using="gin"),
|
||||
Index("source_filename_idx", "filename"),
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_ids(self):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
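To illustrate how the chunk relationship is meant to be used (the connection URL, hash, and model name below are assumptions, not taken from the repo):

    # Hedged sketch: a source item whose chunk ids double as vector-store ids.
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("postgresql:///memory")  # URL is an assumption
    with Session(engine) as session:
        item = MailMessage(sha256=b"\x00" * 32, content="raw rfc822 text", tags=["demo"])
        item.chunks.append(Chunk(content="raw rfc822 text", embedding_model="voyage-3"))  # model name is illustrative
        session.add(item)
        session.commit()
        print(item.vector_ids)  # chunk UUIDs generated by uuid_generate_v4()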
|
||||
|
||||
class MailMessage(Base):
|
||||
__tablename__ = 'mail_message'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
class MailMessage(SourceItem):
|
||||
__tablename__ = "mail_message"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
message_id = Column(Text, unique=True)
|
||||
subject = Column(Text)
|
||||
sender = Column(Text)
|
||||
recipients = Column(ARRAY(Text))
|
||||
sent_at = Column(DateTime(timezone=True))
|
||||
body_raw = Column(Text)
|
||||
folder = Column(Text)
|
||||
tsv = Column(TSVECTOR)
|
||||
|
||||
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
|
||||
source = relationship("SourceItem", back_populates="mail_message")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if not kwargs.get("modality"):
|
||||
kwargs["modality"] = "email"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
attachments = relationship(
|
||||
"EmailAttachment",
|
||||
back_populates="mail_message",
|
||||
foreign_keys="EmailAttachment.mail_message_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "mail_message",
|
||||
}
|
||||
|
||||
@property
|
||||
def attachments_path(self) -> Path:
|
||||
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
|
||||
|
||||
clean_sender = clean_filename(self.sender)
|
||||
clean_folder = clean_filename(self.folder or "INBOX")
|
||||
return Path(settings.FILE_STORAGE_DIR) / clean_sender / clean_folder
|
||||
|
||||
def safe_filename(self, filename: str) -> Path:
|
||||
suffix = Path(filename).suffix
|
||||
name = clean_filename(filename.removesuffix(suffix)) + suffix
|
||||
path = self.attachments_path / name
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
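As a purely illustrative example, the cleaning above maps a sender and folder onto a filesystem-safe layout:

    # Hypothetical: where an attachment named "Q1 report.pdf" would be written.
    msg = MailMessage(sender="alice@example.com", folder="INBOX/Reports", sha256=b"\x01" * 32)
    print(msg.safe_filename("Q1 report.pdf"))
    # -> <FILE_STORAGE_DIR>/alice_example_com/INBOX_Reports/Q1_report.pdf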
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.source_id,
|
||||
"source_id": self.id,
|
||||
"message_id": self.message_id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
"recipients": self.recipients,
|
||||
"folder": self.folder,
|
||||
"tags": self.source.tags,
|
||||
"tags": self.tags,
|
||||
"date": self.sent_at and self.sent_at.isoformat() or None,
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def parsed_content(self) -> EmailMessage:
|
||||
return parse_email_message(self.content, self.message_id)
|
||||
|
||||
@property
|
||||
def body(self) -> str:
|
||||
return self.parsed_content["body"]
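In other words, content is expected to hold the raw RFC822 message while body parses it on demand; a hedged sketch:

    # Hypothetical: content stores the raw message, body extracts the text part.
    raw = "Subject: hi\nFrom: a@example.com\nTo: b@example.com\n\nHello there"
    msg = MailMessage(sha256=b"\x02" * 32, content=raw, message_id="<demo-1>")
    print(msg.body)  # "Hello there", assuming parse_email_message returns a dict with a "body" key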
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('mail_sent_idx', 'sent_at'),
|
||||
Index('mail_recipients_idx', 'recipients', postgresql_using='gin'),
|
||||
Index('mail_tsv_idx', 'tsv', postgresql_using='gin'),
|
||||
Index("mail_sent_idx", "sent_at"),
|
||||
Index("mail_recipients_idx", "recipients", postgresql_using="gin"),
|
||||
Index("mail_tsv_idx", "tsv", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class EmailAttachment(Base):
|
||||
__tablename__ = 'email_attachment'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False)
|
||||
filename = Column(Text, nullable=False)
|
||||
content_type = Column(Text)
|
||||
size = Column(Integer)
|
||||
content = Column(BYTEA) # For small files stored inline
|
||||
file_path = Column(Text) # For larger files stored on disk
|
||||
class EmailAttachment(SourceItem):
|
||||
__tablename__ = "email_attachment"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
mail_message_id = Column(
|
||||
BigInteger, ForeignKey("mail_message.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="attachments")
|
||||
source = relationship("SourceItem", back_populates="attachments")
|
||||
|
||||
|
||||
mail_message = relationship(
|
||||
"MailMessage", back_populates="attachments", foreign_keys=[mail_message_id]
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "email_attachment",
|
||||
}
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"filename": self.filename,
|
||||
"content_type": self.content_type,
|
||||
"content_type": self.mime_type,
|
||||
"size": self.size,
|
||||
"created_at": self.created_at and self.created_at.isoformat() or None,
|
||||
"mail_message_id": self.mail_message_id,
|
||||
"source_id": self.mail_message.source_id,
|
||||
"tags": self.mail_message.source.tags,
|
||||
"source_id": self.id,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('email_attachment_message_idx', 'mail_message_id'),
|
||||
Index('email_attachment_filename_idx', 'filename'),
|
||||
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
||||
|
||||
|
||||
class ChatMessage(SourceItem):
|
||||
__tablename__ = "chat_message"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = 'chat_message'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
platform = Column(Text)
|
||||
channel_id = Column(Text)
|
||||
author = Column(Text)
|
||||
sent_at = Column(DateTime(timezone=True))
|
||||
body_raw = Column(Text)
|
||||
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "chat_message",
|
||||
}
|
||||
|
||||
# Add index
|
||||
__table_args__ = (
|
||||
Index('chat_channel_idx', 'platform', 'channel_id'),
|
||||
__table_args__ = (Index("chat_channel_idx", "platform", "channel_id"),)
|
||||
|
||||
|
||||
class GitCommit(SourceItem):
|
||||
__tablename__ = "git_commit"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class GitCommit(Base):
|
||||
__tablename__ = 'git_commit'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
repo_path = Column(Text)
|
||||
commit_sha = Column(Text, unique=True)
|
||||
author_name = Column(Text)
|
||||
author_email = Column(Text)
|
||||
author_date = Column(DateTime(timezone=True))
|
||||
msg_raw = Column(Text)
|
||||
diff_summary = Column(Text)
|
||||
files_changed = Column(ARRAY(Text))
|
||||
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "git_commit",
|
||||
}
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('git_files_idx', 'files_changed', postgresql_using='gin'),
|
||||
Index('git_date_idx', 'author_date'),
|
||||
Index("git_files_idx", "files_changed", postgresql_using="gin"),
|
||||
Index("git_date_idx", "author_date"),
|
||||
)
|
||||
|
||||
|
||||
class Photo(Base):
|
||||
__tablename__ = 'photo'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
file_path = Column(Text)
|
||||
class Photo(SourceItem):
|
||||
__tablename__ = "photo"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
exif_taken_at = Column(DateTime(timezone=True))
|
||||
exif_lat = Column(Numeric(9, 6))
|
||||
exif_lon = Column(Numeric(9, 6))
|
||||
camera = Column(Text)
|
||||
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "photo",
|
||||
}
|
||||
|
||||
# Add index
|
||||
__table_args__ = (
|
||||
Index('photo_taken_idx', 'exif_taken_at'),
|
||||
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
|
||||
|
||||
|
||||
class BookDoc(SourceItem):
|
||||
__tablename__ = "book_doc"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class BookDoc(Base):
|
||||
__tablename__ = 'book_doc'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
title = Column(Text)
|
||||
author = Column(Text)
|
||||
chapter = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "book_doc",
|
||||
}
|
||||
|
||||
class BlogPost(Base):
|
||||
__tablename__ = 'blog_post'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
class BlogPost(SourceItem):
|
||||
__tablename__ = "blog_post"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
url = Column(Text, unique=True)
|
||||
title = Column(Text)
|
||||
published = Column(DateTime(timezone=True))
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "blog_post",
|
||||
}
|
||||
|
||||
class MiscDoc(Base):
|
||||
__tablename__ = 'misc_doc'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
class MiscDoc(SourceItem):
|
||||
__tablename__ = "misc_doc"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
path = Column(Text)
|
||||
mime_type = Column(Text)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "misc_doc",
|
||||
}
|
||||
|
||||
class RssFeed(Base):
|
||||
__tablename__ = 'rss_feeds'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default='true')
|
||||
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('rss_feeds_active_idx', 'active', 'last_checked_at'),
|
||||
Index('rss_feeds_tags_idx', 'tags', postgresql_using='gin'),
|
||||
|
||||
class GithubItem(SourceItem):
|
||||
__tablename__ = "github_item"
|
||||
|
||||
id = Column(
|
||||
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = 'email_accounts'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default='993')
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default='true')
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default='{}')
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default='true')
|
||||
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('email_accounts_address_idx', 'email_address', unique=True),
|
||||
Index('email_accounts_active_idx', 'active', 'last_sync_at'),
|
||||
Index('email_accounts_tags_idx', 'tags', postgresql_using='gin'),
|
||||
)
|
||||
|
||||
|
||||
class GithubItem(Base):
|
||||
__tablename__ = 'github_item'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
kind = Column(Text, nullable=False)
|
||||
repo_path = Column(Text, nullable=False)
|
||||
number = Column(Integer)
|
||||
@ -259,19 +353,76 @@ class GithubItem(Base):
|
||||
commit_sha = Column(Text)
|
||||
state = Column(Text)
|
||||
title = Column(Text)
|
||||
body_raw = Column(Text)
|
||||
labels = Column(ARRAY(Text))
|
||||
author = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
closed_at = Column(DateTime(timezone=True))
|
||||
merged_at = Column(DateTime(timezone=True))
|
||||
diff_summary = Column(Text)
|
||||
|
||||
|
||||
payload = Column(JSONB)
|
||||
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "github_item",
|
||||
}
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
|
||||
Index('gh_repo_kind_idx', 'repo_path', 'kind'),
|
||||
Index('gh_issue_lookup_idx', 'repo_path', 'kind', 'number'),
|
||||
Index('gh_labels_idx', 'labels', postgresql_using='gin'),
|
||||
)
|
||||
Index("gh_repo_kind_idx", "repo_path", "kind"),
|
||||
Index("gh_issue_lookup_idx", "repo_path", "kind", "number"),
|
||||
Index("gh_labels_idx", "labels", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class RssFeed(Base):
|
||||
__tablename__ = "rss_feeds"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
url = Column(Text, nullable=False, unique=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_checked_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("rss_feeds_active_idx", "active", "last_checked_at"),
|
||||
Index("rss_feeds_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
class EmailAccount(Base):
|
||||
__tablename__ = "email_accounts"
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
name = Column(Text, nullable=False)
|
||||
email_address = Column(Text, nullable=False, unique=True)
|
||||
imap_server = Column(Text, nullable=False)
|
||||
imap_port = Column(Integer, nullable=False, server_default="993")
|
||||
username = Column(Text, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
use_ssl = Column(Boolean, nullable=False, server_default="true")
|
||||
folders = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||
last_sync_at = Column(DateTime(timezone=True))
|
||||
active = Column(Boolean, nullable=False, server_default="true")
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index("email_accounts_address_idx", "email_address", unique=True),
|
||||
Index("email_accounts_active_idx", "active", "last_sync_at"),
|
||||
Index("email_accounts_tags_idx", "tags", postgresql_using="gin"),
|
||||
)
|
||||
|
@ -1,9 +1,17 @@
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Literal, TypedDict, Iterable, Any
|
||||
import voyageai
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Iterable, Literal, TypedDict
|
||||
|
||||
import voyageai
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import extract, settings
|
||||
from memory.common.chunker import chunk_text
|
||||
from memory.common.db.models import Chunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Chunking configuration
|
||||
MAX_TOKENS = 32000 # VoyageAI max context window
|
||||
@ -27,16 +35,24 @@ DEFAULT_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {"dimension": 1024, "distance": "Cosine"},
|
||||
"chat": {"dimension": 1024, "distance": "Cosine"},
|
||||
"git": {"dimension": 1024, "distance": "Cosine"},
|
||||
"photo": {"dimension": 512, "distance": "Cosine"},
|
||||
"book": {"dimension": 1024, "distance": "Cosine"},
|
||||
"blog": {"dimension": 1024, "distance": "Cosine"},
|
||||
"text": {"dimension": 1024, "distance": "Cosine"},
|
||||
# Multimodal
|
||||
"photo": {"dimension": 1024, "distance": "Cosine"},
|
||||
"doc": {"dimension": 1024, "distance": "Cosine"},
|
||||
}
|
||||
|
||||
TYPES = {
|
||||
"doc": ["text/*"],
|
||||
"doc": ["application/pdf", "application/docx", "application/msword"],
|
||||
"text": ["text/*"],
|
||||
"blog": ["text/markdown", "text/html"],
|
||||
"photo": ["image/*"],
|
||||
"book": ["application/pdf", "application/epub+zip", "application/mobi", "application/x-mobipocket-ebook"],
|
||||
"book": [
|
||||
"application/epub+zip",
|
||||
"application/mobi",
|
||||
"application/x-mobipocket-ebook",
|
||||
],
|
||||
}
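If these collection definitions are intended for a Qdrant instance (an assumption based on the dimension/distance shape, not stated here), setting them up might look roughly like this:

    # Hedged sketch: create one vector collection per modality.
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams

    client = QdrantClient(url="http://localhost:6333")  # URL is an assumption
    for name, cfg in DEFAULT_COLLECTIONS.items():
        client.recreate_collection(
            collection_name=name,
            vectors_config=VectorParams(size=cfg["dimension"], distance=Distance(cfg["distance"])),
        )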
|
||||
|
||||
|
||||
@ -52,143 +68,55 @@ def get_modality(mime_type: str) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Regex for sentence splitting
|
||||
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
|
||||
|
||||
|
||||
def approx_token_count(s: str) -> int:
|
||||
return len(s) // CHARS_PER_TOKEN
|
||||
|
||||
|
||||
def yield_word_chunks(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
words = text.split()
|
||||
if not words:
|
||||
return
|
||||
|
||||
current = ""
|
||||
for word in words:
|
||||
new_chunk = f"{current} {word}".strip()
|
||||
if current and approx_token_count(new_chunk) > max_tokens:
|
||||
yield current
|
||||
current = word
|
||||
else:
|
||||
current = new_chunk
|
||||
if current: # Only yield non-empty final chunk
|
||||
yield current
|
||||
|
||||
|
||||
def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Yield text spans in priority order: paragraphs, sentences, words.
|
||||
Each span is guaranteed to be under max_tokens.
|
||||
|
||||
Args:
|
||||
text: The text to split
|
||||
max_tokens: Maximum tokens per chunk
|
||||
|
||||
Yields:
|
||||
Spans of text that fit within token limits
|
||||
"""
|
||||
# Return early for empty text
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
for paragraph in text.split("\n\n"):
|
||||
if not paragraph.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(paragraph) <= max_tokens:
|
||||
yield paragraph
|
||||
continue
|
||||
|
||||
for sentence in _SENT_SPLIT_RE.split(paragraph):
|
||||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
if approx_token_count(sentence) <= max_tokens:
|
||||
yield sentence
|
||||
continue
|
||||
|
||||
for chunk in yield_word_chunks(sentence, max_tokens):
|
||||
yield chunk
|
||||
|
||||
|
||||
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Split text into chunks respecting semantic boundaries while staying within token limits.
|
||||
|
||||
Args:
|
||||
text: The text to chunk
|
||||
max_tokens: Maximum tokens per chunk (default: VoyageAI max context)
|
||||
overlap: Number of tokens to overlap between chunks (default: 200)
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
if approx_token_count(text) <= max_tokens:
|
||||
yield text
|
||||
return
|
||||
|
||||
overlap_chars = overlap * CHARS_PER_TOKEN
|
||||
current = ""
|
||||
|
||||
for span in yield_spans(text, max_tokens):
|
||||
current = f"{current} {span}".strip()
|
||||
if approx_token_count(current) < max_tokens:
|
||||
continue
|
||||
|
||||
if overlap <= 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
overlap_text = current[-overlap_chars:]
|
||||
clean_break = max(
|
||||
overlap_text.rfind(". "),
|
||||
overlap_text.rfind("! "),
|
||||
overlap_text.rfind("? ")
|
||||
)
|
||||
|
||||
if clean_break < 0:
|
||||
yield current
|
||||
current = ""
|
||||
continue
|
||||
|
||||
break_offset = -overlap_chars + clean_break + 1
|
||||
chunk = current[break_offset:].strip()
|
||||
yield current
|
||||
current = chunk
|
||||
|
||||
if current:
|
||||
yield current.strip()
|
||||
|
||||
|
||||
def embed_chunks(chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
def embed_chunks(
|
||||
chunks: list[extract.MulitmodalChunk], model: str = settings.TEXT_EMBEDDING_MODEL
|
||||
) -> list[Vector]:
|
||||
vo = voyageai.Client()
|
||||
return vo.embed(chunks, model=model).embeddings
|
||||
if model == settings.MIXED_EMBEDDING_MODEL:
|
||||
return vo.multimodal_embed(
|
||||
chunks, model=model, input_type="document"
|
||||
).embeddings
|
||||
return vo.embed(chunks, model=model, input_type="document").embeddings
|
||||
|
||||
|
||||
def embed_text(texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
chunks = [c for text in texts for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
|
||||
return embed_chunks(chunks, model)
|
||||
def embed_text(
|
||||
texts: list[str], model: str = settings.TEXT_EMBEDDING_MODEL
|
||||
) -> list[Vector]:
|
||||
chunks = [
|
||||
c
|
||||
for text in texts
|
||||
for c in chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS)
|
||||
if c.strip()
|
||||
]
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
try:
|
||||
return embed_chunks(chunks, model)
|
||||
except voyageai.error.InvalidRequestError as e:
|
||||
logger.error(f"Error embedding text: {e}")
|
||||
logger.debug(f"Text: {texts}")
|
||||
raise
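A minimal sketch of calling this helper (the texts are made up; the model comes from settings):

    # Hypothetical: embed two short documents and inspect the result shape.
    vectors = embed_text(["first document", "second, slightly longer document"])
    print(len(vectors), len(vectors[0]) if vectors else 0)  # e.g. 2 x 1024 for the text model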
|
||||
|
||||
|
||||
def embed_file(file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL) -> list[Vector]:
|
||||
def embed_file(
|
||||
file_path: pathlib.Path, model: str = settings.TEXT_EMBEDDING_MODEL
|
||||
) -> list[Vector]:
|
||||
return embed_text([file_path.read_text()], model)
|
||||
|
||||
|
||||
def embed_mixed(items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL) -> list[Vector]:
|
||||
def embed_mixed(
|
||||
items: list[extract.MulitmodalChunk], model: str = settings.MIXED_EMBEDDING_MODEL
|
||||
) -> list[Vector]:
|
||||
def to_chunks(item: extract.MulitmodalChunk) -> Iterable[str]:
|
||||
if isinstance(item, str):
|
||||
return [c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()]
|
||||
return [
|
||||
c for c in chunk_text(item, MAX_TOKENS, OVERLAP_TOKENS) if c.strip()
|
||||
]
|
||||
return [item]
|
||||
|
||||
chunks = [c for item in items for c in to_chunks(item)]
|
||||
return embed_chunks(chunks, model)
|
||||
return embed_chunks([chunks], model)
|
||||
|
||||
|
||||
def embed_page(page: dict[str, Any]) -> list[Vector]:
|
||||
@ -198,17 +126,64 @@ def embed_page(page: dict[str, Any]) -> list[Vector]:
|
||||
return embed_mixed(contents, model=settings.MIXED_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
|
||||
if isinstance(item, str):
|
||||
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.txt"
|
||||
filename.write_text(item)
|
||||
elif isinstance(item, bytes):
|
||||
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.bin"
|
||||
filename.write_bytes(item)
|
||||
elif isinstance(item, Image.Image):
|
||||
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}.png"
|
||||
item.save(filename)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(item)}")
|
||||
return filename
|
||||
|
||||
|
||||
def make_chunk(
|
||||
page: extract.Page, vector: Vector, metadata: dict[str, Any] = {}
|
||||
) -> Chunk:
|
||||
"""Create a Chunk object from a page and a vector.
|
||||
|
||||
This is quite complex, because we need to handle the case where the page is a single string,
|
||||
a single image, or a list of strings and images.
|
||||
"""
|
||||
chunk_id = str(uuid.uuid4())
|
||||
contents = page["contents"]
|
||||
content, filename = None, None
|
||||
if all(isinstance(c, str) for c in contents):
|
||||
content = "\n\n".join(contents)
|
||||
model = settings.TEXT_EMBEDDING_MODEL
|
||||
elif len(contents) == 1:
|
||||
filename = (write_to_file(chunk_id, contents[0]),)
|
||||
model = settings.MIXED_EMBEDDING_MODEL
|
||||
else:
|
||||
for i, item in enumerate(contents):
|
||||
write_to_file(f"{chunk_id}_{i}", item)
|
||||
model = settings.MIXED_EMBEDDING_MODEL
|
||||
filename = settings.FILE_STORAGE_DIR / f"{chunk_id}_*"
|
||||
|
||||
return Chunk(
|
||||
id=chunk_id,
|
||||
file_path=filename,
|
||||
content=content,
|
||||
embedding_model=model,
|
||||
vector=vector,
|
||||
item_metadata=metadata,
|
||||
)
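# Hedged sketch (not in the commit) of how make_chunk ties a Page to its vector.
# The Page shape and field names follow the code above; the example text is made up.
#
#     page = {"contents": ["first paragraph", "second paragraph"], "metadata": {}}
#     chunk = make_chunk(page, vector=[0.1, 0.2, 0.3], metadata={"source": "demo"})
#     # all-text contents are joined into chunk.content and embedded with the
#     # TEXT_EMBEDDING_MODEL; image or mixed contents are written to disk instead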
|
||||
|
||||
|
||||
def embed(
|
||||
mime_type: str,
|
||||
content: bytes | str | pathlib.Path,
|
||||
metadata: dict[str, Any] = {},
|
||||
) -> tuple[str, list[Embedding]]:
|
||||
modality = get_modality(mime_type)
|
||||
|
||||
pages = extract.extract_content(mime_type, content)
|
||||
vectors = [
|
||||
(str(uuid.uuid4()), vector, page.get("metadata", {}) | metadata)
|
||||
chunks = [
|
||||
make_chunk(page, vector, metadata)
|
||||
for page in pages
|
||||
for vector in embed_page(page)
|
||||
]
|
||||
return modality, vectors
|
||||
return modality, chunks
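# Rough usage sketch (illustrative only): embed() now returns Chunk objects rather
# than bare (id, vector, payload) tuples.
#
#     modality, chunks = embed("text/plain", "hello world", metadata={"tag": "demo"})
#     # modality depends on get_modality(mime_type); each chunk carries its own id,
#     # vector and item_metadata, ready to be upserted into Qdrant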
|
||||
|
@ -7,6 +7,7 @@ from PIL import Image
|
||||
from typing import Any, TypedDict, Generator
|
||||
|
||||
|
||||
|
||||
MulitmodalChunk = Image.Image | str
|
||||
class Page(TypedDict):
|
||||
contents: list[MulitmodalChunk]
|
||||
@ -33,14 +34,16 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
|
||||
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
with as_file(content) as file_path:
|
||||
with pymupdf.open(file_path) as pdf:
|
||||
return [{
|
||||
"contents": page_to_image(page),
|
||||
"metadata": {
|
||||
"page": page.number,
|
||||
"width": page.rect.width,
|
||||
"height": page.rect.height,
|
||||
}
|
||||
} for page in pdf.pages()]
|
||||
return [
|
||||
{
|
||||
"contents": page_to_image(page),
|
||||
"metadata": {
|
||||
"page": page.number,
|
||||
"width": page.rect.width,
|
||||
"height": page.rect.height,
|
||||
}
|
||||
} for page in pdf.pages()
|
||||
]
|
||||
|
||||
|
||||
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
@ -50,7 +53,7 @@ def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
image = Image.open(io.BytesIO(content))
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
return [{"contents": image, "metadata": {}}]
|
||||
return [{"contents": [image], "metadata": {}}]
|
||||
|
||||
|
||||
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
|
177
src/memory/common/parsers/email.py
Normal file
@ -0,0 +1,177 @@
|
||||
import email
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import TypedDict, Literal
|
||||
import pathlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attachment(TypedDict):
|
||||
filename: str
|
||||
content_type: str
|
||||
size: int
|
||||
content: bytes
|
||||
path: pathlib.Path
|
||||
|
||||
|
||||
class EmailMessage(TypedDict):
|
||||
message_id: str
|
||||
subject: str
|
||||
sender: str
|
||||
recipients: list[str]
|
||||
sent_at: datetime | None
|
||||
body: str
|
||||
attachments: list[Attachment]
|
||||
|
||||
|
||||
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
|
||||
|
||||
|
||||
def extract_recipients(msg: email.message.Message) -> list[str]:
|
||||
"""
|
||||
Extract email recipients from message headers.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
List of recipient email addresses
|
||||
"""
|
||||
return [
|
||||
recipient
|
||||
for field in ["To", "Cc", "Bcc"]
|
||||
if (field_value := msg.get(field, ""))
|
||||
for r in field_value.split(",")
|
||||
if (recipient := r.strip())
|
||||
]
|
||||
|
||||
|
||||
def extract_date(msg: email.message.Message) -> datetime | None:
|
||||
"""
|
||||
Parse date from email header.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
Parsed datetime or None if parsing failed
|
||||
"""
|
||||
if date_str := msg.get("Date"):
|
||||
try:
|
||||
return parsedate_to_datetime(date_str)
|
||||
except Exception:
|
||||
logger.warning(f"Could not parse date: {date_str}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_body(msg: email.message.Message) -> str:
|
||||
"""
|
||||
Extract plain text body from email message.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
Plain text body content
|
||||
"""
|
||||
body = ""
|
||||
|
||||
if not msg.is_multipart():
|
||||
try:
|
||||
return msg.get_payload(decode=True).decode(errors='replace')
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding message body: {str(e)}")
|
||||
return ""
|
||||
|
||||
for part in msg.walk():
|
||||
content_type = part.get_content_type()
|
||||
content_disposition = str(part.get("Content-Disposition", ""))
|
||||
|
||||
if content_type == "text/plain" and "attachment" not in content_disposition:
|
||||
try:
|
||||
body += part.get_payload(decode=True).decode(errors='replace') + "\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding message part: {str(e)}")
|
||||
return body
|
||||
|
||||
|
||||
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
|
||||
"""
|
||||
Extract attachment metadata and content from email.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
List of attachment dictionaries with metadata and content
|
||||
"""
|
||||
if not msg.is_multipart():
|
||||
return []
|
||||
|
||||
attachments = []
|
||||
for part in msg.walk():
|
||||
content_disposition = part.get("Content-Disposition", "")
|
||||
if "attachment" not in content_disposition:
|
||||
continue
|
||||
|
||||
if filename := part.get_filename():
|
||||
try:
|
||||
content = part.get_payload(decode=True)
|
||||
attachments.append({
|
||||
"filename": filename,
|
||||
"content_type": part.get_content_type(),
|
||||
"size": len(content),
|
||||
"content": content
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
|
||||
|
||||
return attachments
|
||||
|
||||
|
||||
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
|
||||
"""
|
||||
Compute a SHA-256 hash of message content.
|
||||
|
||||
Args:
|
||||
msg_id: Message ID
|
||||
subject: Email subject
|
||||
sender: Sender email
|
||||
body: Message body
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as bytes
|
||||
"""
|
||||
hash_content = (msg_id + subject + sender + body).encode()
|
||||
return hashlib.sha256(hash_content).digest()
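# Quick illustration (behaviour follows directly from the code above): the hash is
# deterministic for identical inputs and changes if any field differs.
#
#     h1 = compute_message_hash("<id@example.com>", "Subject", "a@b.com", "body")
#     h2 = compute_message_hash("<id@example.com>", "Subject", "a@b.com", "body")
#     assert h1 == h2 and len(h1) == 32   # SHA-256 digest is 32 bytes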
|
||||
|
||||
|
||||
def parse_email_message(raw_email: str, message_id: str) -> EmailMessage:
|
||||
"""
|
||||
Parse raw email into structured data.
|
||||
|
||||
Args:
|
||||
raw_email: Raw email content as string
message_id: Fallback ID used to synthesise a message ID when the Message-ID header is missing
|
||||
|
||||
Returns:
|
||||
Dict with parsed email data
|
||||
"""
|
||||
msg = email.message_from_string(raw_email)
|
||||
message_id = msg.get("Message-ID") or f"generated-{message_id}"
|
||||
subject = msg.get("Subject", "")
|
||||
from_ = msg.get("From", "")
|
||||
body = extract_body(msg)
|
||||
|
||||
return EmailMessage(
|
||||
message_id=message_id,
|
||||
subject=subject,
|
||||
sender=from_,
|
||||
recipients=extract_recipients(msg),
|
||||
sent_at=extract_date(msg),
|
||||
body=body,
|
||||
attachments=extract_attachments(msg),
|
||||
hash=compute_message_hash(message_id, subject, from_, body)
|
||||
)
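# Hedged usage sketch (not in the commit): parsing a raw RFC 822 string pulled from
# IMAP. The second argument is only used to synthesise a message ID when the header
# is missing.
#
#     raw = "From: a@b.com\r\nTo: c@d.com\r\nSubject: Hi\r\n\r\nHello"
#     parsed = parse_email_message(raw, message_id="42")
#     # parsed["sender"] == "a@b.com"; parsed["hash"] is the SHA-256 digest used for
#     # de-duplication further down the pipeline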
|
@ -84,7 +84,9 @@ def initialize_collections(client: qdrant_client.QdrantClient, collections: dict
|
||||
if collections is None:
|
||||
collections = DEFAULT_COLLECTIONS
|
||||
|
||||
logger.info(f"Initializing collections:")
|
||||
for name, params in collections.items():
|
||||
logger.info(f" - {name}")
|
||||
ensure_collection_exists(
|
||||
client,
|
||||
collection_name=name,
|
||||
|
@ -39,8 +39,8 @@ QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
|
||||
|
||||
|
||||
# Worker settings
|
||||
BEAT_LOOP_INTERVAL = int(os.getenv("BEAT_LOOP_INTERVAL", 3600))
|
||||
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 600))
|
||||
EMAIL_SYNC_INTERVAL = int(os.getenv("EMAIL_SYNC_INTERVAL", 3600))
|
||||
EMAIL_SYNC_INTERVAL = 60
|
||||
|
||||
|
||||
# Embedding settings
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
from celery import Celery
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
def rabbit_url() -> str:
|
||||
user = os.getenv("RABBITMQ_USER", "guest")
|
||||
password = os.getenv("RABBITMQ_PASSWORD", "guest")
|
||||
@ -32,3 +33,10 @@ app.conf.update(
|
||||
"memory.workers.tasks.docs.*": {"queue": "docs"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@app.on_after_configure.connect
|
||||
def ensure_qdrant_initialised(sender, **_):
|
||||
from memory.common import qdrant
|
||||
qdrant.setup_qdrant()
|
@ -1,200 +1,80 @@
|
||||
import email
|
||||
import hashlib
|
||||
import imaplib
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
import base64
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Generator, Callable, TypedDict, Literal
|
||||
from typing import Generator, Callable
|
||||
import pathlib
|
||||
from sqlalchemy.orm import Session
|
||||
from collections import defaultdict
|
||||
from memory.common import settings, embedding
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common import qdrant
|
||||
from memory.common import settings, embedding, qdrant
|
||||
from memory.common.db.models import (
|
||||
EmailAccount,
|
||||
MailMessage,
|
||||
SourceItem,
|
||||
EmailAttachment,
|
||||
)
|
||||
from memory.common.parsers.email import (
|
||||
Attachment,
|
||||
parse_email_message,
|
||||
RawEmailResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attachment(TypedDict):
|
||||
filename: str
|
||||
content_type: str
|
||||
size: int
|
||||
content: bytes
|
||||
path: pathlib.Path
|
||||
|
||||
|
||||
class EmailMessage(TypedDict):
|
||||
message_id: str
|
||||
subject: str
|
||||
sender: str
|
||||
recipients: list[str]
|
||||
sent_at: datetime | None
|
||||
body: str
|
||||
attachments: list[Attachment]
|
||||
|
||||
|
||||
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
|
||||
|
||||
|
||||
def extract_recipients(msg: email.message.Message) -> list[str]:
|
||||
"""
|
||||
Extract email recipients from message headers.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
List of recipient email addresses
|
||||
"""
|
||||
return [
|
||||
recipient
|
||||
for field in ["To", "Cc", "Bcc"]
|
||||
if (field_value := msg.get(field, ""))
|
||||
for r in field_value.split(",")
|
||||
if (recipient := r.strip())
|
||||
]
|
||||
|
||||
|
||||
def extract_date(msg: email.message.Message) -> datetime | None:
|
||||
"""
|
||||
Parse date from email header.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
Parsed datetime or None if parsing failed
|
||||
"""
|
||||
if date_str := msg.get("Date"):
|
||||
try:
|
||||
return parsedate_to_datetime(date_str)
|
||||
except Exception:
|
||||
logger.warning(f"Could not parse date: {date_str}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_body(msg: email.message.Message) -> str:
|
||||
"""
|
||||
Extract plain text body from email message.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
Plain text body content
|
||||
"""
|
||||
body = ""
|
||||
|
||||
if not msg.is_multipart():
|
||||
try:
|
||||
return msg.get_payload(decode=True).decode(errors='replace')
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding message body: {str(e)}")
|
||||
return ""
|
||||
|
||||
for part in msg.walk():
|
||||
content_type = part.get_content_type()
|
||||
content_disposition = str(part.get("Content-Disposition", ""))
|
||||
|
||||
if content_type == "text/plain" and "attachment" not in content_disposition:
|
||||
try:
|
||||
body += part.get_payload(decode=True).decode(errors='replace') + "\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding message part: {str(e)}")
|
||||
return body
|
||||
|
||||
|
||||
def extract_attachments(msg: email.message.Message) -> list[Attachment]:
|
||||
"""
|
||||
Extract attachment metadata and content from email.
|
||||
|
||||
Args:
|
||||
msg: Email message object
|
||||
|
||||
Returns:
|
||||
List of attachment dictionaries with metadata and content
|
||||
"""
|
||||
if not msg.is_multipart():
|
||||
return []
|
||||
|
||||
attachments = []
|
||||
for part in msg.walk():
|
||||
content_disposition = part.get("Content-Disposition", "")
|
||||
if "attachment" not in content_disposition:
|
||||
continue
|
||||
|
||||
if filename := part.get_filename():
|
||||
try:
|
||||
content = part.get_payload(decode=True)
|
||||
attachments.append({
|
||||
"filename": filename,
|
||||
"content_type": part.get_content_type(),
|
||||
"size": len(content),
|
||||
"content": content
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting attachment content for {filename}: {str(e)}")
|
||||
|
||||
return attachments
|
||||
|
||||
|
||||
def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAttachment | None:
|
||||
def process_attachment(
|
||||
attachment: Attachment, message: MailMessage
|
||||
) -> EmailAttachment | None:
|
||||
"""Process an attachment, storing large files on disk and returning metadata.
|
||||
|
||||
|
||||
Args:
|
||||
attachment: Attachment dictionary with metadata and content
|
||||
message: MailMessage the attachment belongs to (used to build the file path)
|
||||
|
||||
|
||||
Returns:
|
||||
Processed attachment dictionary with appropriate metadata
|
||||
"""
|
||||
content, file_path = None, None
|
||||
if not (real_content := attachment.get("content")):
|
||||
"No content, so just save the metadata"
|
||||
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE:
|
||||
content = base64.b64encode(real_content)
|
||||
elif attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE and attachment[
|
||||
"content_type"
|
||||
].startswith("text/"):
|
||||
content = real_content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
safe_filename = re.sub(r'[/\\]', '_', attachment["filename"])
|
||||
user_dir = message.attachments_path
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = user_dir / safe_filename
|
||||
file_path = message.safe_filename(attachment["filename"])
|
||||
try:
|
||||
file_path.write_bytes(real_content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
|
||||
logger.error(f"Failed to save attachment {file_path} to disk: {str(e)}")
|
||||
return None
|
||||
|
||||
source_item = SourceItem(
|
||||
modality=embedding.get_modality(attachment["content_type"]),
|
||||
sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(),
|
||||
tags=message.source.tags,
|
||||
byte_length=attachment["size"],
|
||||
mime_type=attachment["content_type"],
|
||||
)
|
||||
|
||||
return EmailAttachment(
|
||||
source=source_item,
|
||||
filename=attachment["filename"],
|
||||
modality=embedding.get_modality(attachment["content_type"]),
|
||||
sha256=hashlib.sha256(
|
||||
real_content if real_content else str(attachment).encode()
|
||||
).digest(),
|
||||
tags=message.tags,
|
||||
size=attachment["size"],
|
||||
mime_type=attachment["content_type"],
|
||||
mail_message=message,
|
||||
content_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
content=content,
|
||||
file_path=file_path and str(file_path),
|
||||
filename=file_path and str(file_path),
|
||||
)
|
||||
|
||||
|
||||
def process_attachments(attachments: list[Attachment], message: MailMessage) -> list[EmailAttachment]:
|
||||
def process_attachments(
|
||||
attachments: list[Attachment], message: MailMessage
|
||||
) -> list[EmailAttachment]:
|
||||
"""
|
||||
Process email attachments, storing large files on disk and returning metadata.
|
||||
|
||||
|
||||
Args:
|
||||
attachments: List of attachment dictionaries with metadata and content
|
||||
message: MailMessage the attachments belong to (used to build file paths)
|
||||
|
||||
|
||||
Returns:
|
||||
List of processed attachment dictionaries with appropriate metadata
|
||||
"""
|
||||
@ -203,150 +83,111 @@ def process_attachments(attachments: list[Attachment], message: MailMessage) ->
|
||||
|
||||
return [
|
||||
attachment
|
||||
for a in attachments if (attachment := process_attachment(a, message))
|
||||
for a in attachments
|
||||
if (attachment := process_attachment(a, message))
|
||||
]
|
||||
|
||||
|
||||
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
|
||||
"""
|
||||
Compute a SHA-256 hash of message content.
|
||||
|
||||
Args:
|
||||
msg_id: Message ID
|
||||
subject: Email subject
|
||||
sender: Sender email
|
||||
body: Message body
|
||||
|
||||
Returns:
|
||||
SHA-256 hash as bytes
|
||||
"""
|
||||
hash_content = (msg_id + subject + sender + body).encode()
|
||||
return hashlib.sha256(hash_content).digest()
|
||||
|
||||
|
||||
def parse_email_message(raw_email: str) -> EmailMessage:
|
||||
"""
|
||||
Parse raw email into structured data.
|
||||
|
||||
Args:
|
||||
raw_email: Raw email content as string
|
||||
|
||||
Returns:
|
||||
Dict with parsed email data
|
||||
"""
|
||||
msg = email.message_from_string(raw_email)
|
||||
|
||||
return EmailMessage(
|
||||
message_id=msg.get("Message-ID", ""),
|
||||
subject=msg.get("Subject", ""),
|
||||
sender=msg.get("From", ""),
|
||||
recipients=extract_recipients(msg),
|
||||
sent_at=extract_date(msg),
|
||||
body=extract_body(msg),
|
||||
attachments=extract_attachments(msg)
|
||||
)
|
||||
|
||||
|
||||
def create_source_item(
|
||||
db_session: Session,
|
||||
message_hash: bytes,
|
||||
account_tags: list[str],
|
||||
raw_size: int,
|
||||
modality: str = "mail",
|
||||
mime_type: str = "message/rfc822",
|
||||
) -> SourceItem:
|
||||
"""
|
||||
Create a new source item record.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
message_hash: SHA-256 hash of message
|
||||
account_tags: Tags from the email account
|
||||
raw_size: Size of raw email in bytes
|
||||
|
||||
Returns:
|
||||
Newly created SourceItem
|
||||
"""
|
||||
source_item = SourceItem(
|
||||
modality=modality,
|
||||
sha256=message_hash,
|
||||
tags=account_tags,
|
||||
byte_length=raw_size,
|
||||
mime_type=mime_type,
|
||||
embed_status="RAW"
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
return source_item
|
||||
|
||||
|
||||
def create_mail_message(
|
||||
db_session: Session,
|
||||
source_item: SourceItem,
|
||||
parsed_email: EmailMessage,
|
||||
tags: list[str],
|
||||
folder: str,
|
||||
raw_email: str,
|
||||
message_id: str,
|
||||
) -> MailMessage:
|
||||
"""
|
||||
Create a new mail message record and associated attachments.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
source_id: ID of the SourceItem
|
||||
parsed_email: Parsed email data
|
||||
folder: IMAP folder name
|
||||
|
||||
|
||||
Returns:
|
||||
Newly created MailMessage
|
||||
"""
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
mail_message = MailMessage(
|
||||
source=source_item,
|
||||
modality="mail",
|
||||
sha256=parsed_email["hash"],
|
||||
tags=tags,
|
||||
size=len(raw_email),
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
message_id=parsed_email["message_id"],
|
||||
subject=parsed_email["subject"],
|
||||
sender=parsed_email["sender"],
|
||||
recipients=parsed_email["recipients"],
|
||||
sent_at=parsed_email["sent_at"],
|
||||
body_raw=parsed_email["body"],
|
||||
content=raw_email,
|
||||
folder=folder,
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
if parsed_email["attachments"]:
|
||||
processed_attachments = process_attachments(parsed_email["attachments"], mail_message)
|
||||
db_session.add_all(processed_attachments)
|
||||
|
||||
mail_message.attachments = process_attachments(
|
||||
parsed_email["attachments"], mail_message
|
||||
)
|
||||
|
||||
db_session.add(mail_message)
|
||||
|
||||
return mail_message
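# Hedged sketch (assumption: mirrors the call site in the email task further below,
# where this helper is invoked as create_mail_message(db, account.tags, folder,
# raw_email, message_id)): it parses the raw email, builds the MailMessage row and
# attaches any processed attachments before returning it, leaving the commit to the caller.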
|
||||
|
||||
|
||||
def check_message_exists(db_session: Session, message_id: str, message_hash: bytes) -> bool:
|
||||
def does_message_exist(
|
||||
db_session: Session, message_id: str, message_hash: bytes
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a message already exists in the database.
|
||||
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
message_id: Email message ID
|
||||
message_hash: SHA-256 hash of message
|
||||
|
||||
|
||||
Returns:
|
||||
True if message exists, False otherwise
|
||||
"""
|
||||
# Check by message_id first (faster)
|
||||
if message_id:
|
||||
mail_message = db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first()
|
||||
mail_message = (
|
||||
db_session.query(MailMessage)
|
||||
.filter(MailMessage.message_id == message_id)
|
||||
.first()
|
||||
)
|
||||
if mail_message is not None:
|
||||
return True
|
||||
|
||||
|
||||
# Then check by message_hash
|
||||
source_item = db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
|
||||
source_item = (
|
||||
db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
|
||||
)
|
||||
return source_item is not None
|
||||
|
||||
|
||||
def check_message_exists(
|
||||
db: Session, account_id: int, message_id: str, raw_email: str
|
||||
) -> bool:
|
||||
account = db.query(EmailAccount).get(account_id)
|
||||
if not account:
|
||||
logger.error(f"Account {account_id} not found")
|
||||
return False
|
||||
|
||||
parsed_email = parse_email_message(raw_email, message_id)
|
||||
|
||||
# Use server-provided message ID if missing
|
||||
if not parsed_email["message_id"]:
|
||||
parsed_email["message_id"] = f"generated-{message_id}"
|
||||
|
||||
return does_message_exist(db, parsed_email["message_id"], parsed_email["hash"])
|
||||
|
||||
|
||||
def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
|
||||
"""
|
||||
Extract the UID and raw email data from the message data.
|
||||
"""
|
||||
uid_pattern = re.compile(r'UID (\d+)')
|
||||
uid_match = uid_pattern.search(msg_data[0][0].decode('utf-8', errors='replace'))
|
||||
uid_pattern = re.compile(r"UID (\d+)")
|
||||
uid_match = uid_pattern.search(msg_data[0][0].decode("utf-8", errors="replace"))
|
||||
uid = uid_match.group(1) if uid_match else None
|
||||
raw_email = msg_data[0][1]
|
||||
return uid, raw_email
|
||||
@ -354,11 +195,11 @@ def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
|
||||
|
||||
def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
|
||||
try:
|
||||
status, msg_data = conn.fetch(uid, '(UID RFC822)')
|
||||
if status != 'OK' or not msg_data or not msg_data[0]:
|
||||
status, msg_data = conn.fetch(uid, "(UID RFC822)")
|
||||
if status != "OK" or not msg_data or not msg_data[0]:
|
||||
logger.error(f"Error fetching message {uid}")
|
||||
return None
|
||||
|
||||
|
||||
return extract_email_uid(msg_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message {uid}: {str(e)}")
|
||||
@ -368,38 +209,38 @@ def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
|
||||
def fetch_email_since(
|
||||
conn: imaplib.IMAP4_SSL,
|
||||
folder: str,
|
||||
since_date: datetime = datetime(1970, 1, 1)
|
||||
since_date: datetime = datetime(1970, 1, 1),
|
||||
) -> list[RawEmailResponse]:
|
||||
"""
|
||||
Fetch emails from a folder since a given date.
|
||||
|
||||
Fetch emails from a folder since a given date and time.
|
||||
|
||||
Args:
|
||||
conn: IMAP connection
|
||||
folder: Folder name to select
|
||||
since_date: Fetch emails since this date
|
||||
|
||||
since_date: Fetch emails since this date and time
|
||||
|
||||
Returns:
|
||||
List of tuples with (uid, raw_email)
|
||||
"""
|
||||
try:
|
||||
status, counts = conn.select(folder)
|
||||
if status != 'OK':
|
||||
if status != "OK":
|
||||
logger.error(f"Error selecting folder {folder}: {counts}")
|
||||
return []
|
||||
|
||||
|
||||
date_str = since_date.strftime("%d-%b-%Y")
|
||||
|
||||
|
||||
status, data = conn.search(None, f'(SINCE "{date_str}")')
|
||||
if status != 'OK':
|
||||
if status != "OK":
|
||||
logger.error(f"Error searching folder {folder}: {data}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fetch_email_since for folder {folder}: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
if not data or not data[0]:
|
||||
return []
|
||||
|
||||
|
||||
return [email for uid in data[0].split() if (email := fetch_email(conn, uid))]
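# Small clarification (derived from the strftime call above): IMAP SEARCH only
# understands day-granularity dates, so the time-of-day part of since_date is
# effectively ignored by the server-side filter.
#
#     # since_date=datetime(2023, 1, 1, 15, 30) -> search criterion '(SINCE "01-Jan-2023")'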
|
||||
|
||||
|
||||
@ -412,13 +253,14 @@ def process_folder(
|
||||
) -> dict:
|
||||
"""
|
||||
Process a single folder from an email account.
|
||||
|
||||
|
||||
Args:
|
||||
conn: Active IMAP connection
|
||||
folder: Folder name to process
|
||||
account: Email account configuration
|
||||
since_date: Only fetch messages newer than this date
|
||||
|
||||
processor: Function to process each message
|
||||
|
||||
Returns:
|
||||
Stats dictionary for the folder
|
||||
"""
|
||||
@ -427,21 +269,21 @@ def process_folder(
|
||||
|
||||
try:
|
||||
emails = fetch_email_since(conn, folder, since_date)
|
||||
|
||||
|
||||
for uid, raw_email in emails:
|
||||
try:
|
||||
task = processor(
|
||||
account_id=account.id,
|
||||
message_id=uid,
|
||||
folder=folder,
|
||||
raw_email=raw_email.decode('utf-8', errors='replace')
|
||||
raw_email=raw_email.decode("utf-8", errors="replace"),
|
||||
)
|
||||
if task:
|
||||
new_messages += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error queuing message {uid}: {str(e)}")
|
||||
errors += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing folder {folder}: {str(e)}")
|
||||
errors += 1
|
||||
@ -449,16 +291,13 @@ def process_folder(
|
||||
return {
|
||||
"messages_found": len(emails),
|
||||
"new_messages": new_messages,
|
||||
"errors": errors
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
|
||||
conn = imaplib.IMAP4_SSL(
|
||||
host=account.imap_server,
|
||||
port=account.imap_port
|
||||
)
|
||||
conn = imaplib.IMAP4_SSL(host=account.imap_server, port=account.imap_port)
|
||||
try:
|
||||
conn.login(account.username, account.password)
|
||||
yield conn
|
||||
@ -470,43 +309,56 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
||||
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
|
||||
|
||||
|
||||
def vectorize_email(email: MailMessage) -> list[float]:
|
||||
def vectorize_email(email: MailMessage):
|
||||
qdrant_client = qdrant.get_qdrant_client()
|
||||
|
||||
_, chunks = embedding.embed(
|
||||
"text/plain", email.body_raw, metadata=email.as_payload(),
|
||||
"text/plain",
|
||||
email.body,
|
||||
metadata=email.as_payload(),
|
||||
)
|
||||
vector_ids, vectors, metadata = zip(*chunks)
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=metadata,
|
||||
)
|
||||
vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
|
||||
email.chunks = chunks
|
||||
if chunks:
|
||||
vector_ids = [c.id for c in chunks]
|
||||
vectors = [c.vector for c in chunks]
|
||||
metadata = [c.item_metadata for c in chunks]
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=metadata,
|
||||
)
|
||||
|
||||
embeds = defaultdict(list)
|
||||
for attachment in email.attachments:
|
||||
if attachment.file_path:
|
||||
content = pathlib.Path(attachment.file_path).read_bytes()
|
||||
if attachment.filename:
|
||||
content = pathlib.Path(attachment.filename).read_bytes()
|
||||
else:
|
||||
content = attachment.content
|
||||
collection, chunks = embedding.embed(attachment.content_type, content, metadata=attachment.as_payload())
|
||||
ids, vectors, metadata = zip(*chunks)
|
||||
attachment.source.vector_ids = ids
|
||||
collection, chunks = embedding.embed(
|
||||
attachment.mime_type, content, metadata=attachment.as_payload()
|
||||
)
|
||||
if not chunks:
|
||||
continue
|
||||
|
||||
attachment.chunks = chunks
|
||||
embeds[collection].extend(chunks)
|
||||
|
||||
for collection, chunks in embeds.items():
|
||||
ids, vectors, payloads = zip(*chunks)
|
||||
ids = [c.id for c in chunks]
|
||||
vectors = [c.vector for c in chunks]
|
||||
metadata = [c.item_metadata for c in chunks]
|
||||
qdrant.upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name=collection,
|
||||
ids=ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
payloads=metadata,
|
||||
)
|
||||
vector_ids.extend([f"{collection}/{vector_id}" for vector_id in ids])
|
||||
|
||||
email.embed_status = "STORED"
|
||||
for attachment in email.attachments:
|
||||
attachment.embed_status = "STORED"
|
||||
|
||||
logger.info(f"Stored embedding for message {email.message_id}")
|
||||
return vector_ids
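# Summary note (illustrative): instead of returning raw vector ids, vectorize_email
# now attaches Chunk objects to the message and its attachments, and the ids, vectors
# and payloads handed to qdrant.upsert_vectors are read off those chunks.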
|
||||
|
@ -1,14 +1,11 @@
|
||||
from celery.schedules import schedule
|
||||
from memory.workers.tasks.email import SYNC_ALL_ACCOUNTS
|
||||
from memory.common import settings
|
||||
|
||||
from memory.workers.celery_app import app
|
||||
from memory.common import settings
|
||||
|
||||
|
||||
@app.on_after_configure.connect
|
||||
def register_mail_schedules(sender, **_):
|
||||
sender.add_periodic_task(
|
||||
schedule=schedule(settings.EMAIL_SYNC_INTERVAL),
|
||||
sig=app.signature(SYNC_ALL_ACCOUNTS),
|
||||
name="sync-mail-all",
|
||||
options={"queue": "email"},
|
||||
)
|
||||
app.conf.beat_schedule = {
|
||||
'sync-mail-all': {
|
||||
'task': 'memory.workers.tasks.email.sync_all_accounts',
|
||||
'schedule': settings.EMAIL_SYNC_INTERVAL,
|
||||
},
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
"""
|
||||
Import sub-modules so Celery can register their @app.task decorators.
|
||||
"""
|
||||
from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa
|
||||
from memory.workers.tasks import docs, email # noqa
|
||||
from memory.workers.tasks.email import SYNC_ACCOUNT, SYNC_ALL_ACCOUNTS, PROCESS_EMAIL
|
||||
|
||||
|
||||
__all__ = ["text", "photo", "ocr", "git", "rss", "docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]
|
||||
__all__ = ["docs", "email", "SYNC_ACCOUNT", "SYNC_ALL_ACCOUNTS", "PROCESS_EMAIL"]
|
@ -6,11 +6,8 @@ from memory.common.db.models import EmailAccount
|
||||
from memory.workers.celery_app import app
|
||||
from memory.workers.email import (
|
||||
check_message_exists,
|
||||
compute_message_hash,
|
||||
create_mail_message,
|
||||
create_source_item,
|
||||
imap_connection,
|
||||
parse_email_message,
|
||||
process_folder,
|
||||
vectorize_email,
|
||||
)
|
||||
@ -25,17 +22,20 @@ SYNC_ALL_ACCOUNTS = "memory.workers.tasks.email.sync_all_accounts"
|
||||
|
||||
@app.task(name=PROCESS_EMAIL)
|
||||
def process_message(
|
||||
account_id: int, message_id: str, folder: str, raw_email: str,
|
||||
account_id: int,
|
||||
message_id: str,
|
||||
folder: str,
|
||||
raw_email: str,
|
||||
) -> int | None:
|
||||
"""
|
||||
Process a single email message and store it in the database.
|
||||
|
||||
|
||||
Args:
|
||||
account_id: ID of the EmailAccount
|
||||
message_id: UID of the message on the server
|
||||
folder: Folder name where the message is stored
|
||||
raw_email: Raw email content as string
|
||||
|
||||
|
||||
Returns:
|
||||
source_id if successful, None otherwise
|
||||
"""
|
||||
@ -45,88 +45,86 @@ def process_message(
|
||||
return None
|
||||
|
||||
with make_session() as db:
|
||||
if check_message_exists(db, account_id, message_id, raw_email):
|
||||
logger.debug(f"Message {message_id} already exists in database")
|
||||
return None
|
||||
|
||||
account = db.query(EmailAccount).get(account_id)
|
||||
if not account:
|
||||
logger.error(f"Account {account_id} not found")
|
||||
return None
|
||||
|
||||
parsed_email = parse_email_message(raw_email)
|
||||
|
||||
# Use server-provided message ID if missing
|
||||
if not parsed_email["message_id"]:
|
||||
parsed_email["message_id"] = f"generated-{message_id}"
|
||||
|
||||
message_hash = compute_message_hash(
|
||||
parsed_email["message_id"],
|
||||
parsed_email["subject"],
|
||||
parsed_email["sender"],
|
||||
parsed_email["body"]
|
||||
|
||||
mail_message = create_mail_message(
|
||||
db, account.tags, folder, raw_email, message_id
|
||||
)
|
||||
|
||||
if check_message_exists(db, parsed_email["message_id"], message_hash):
|
||||
logger.debug(f"Message {parsed_email['message_id']} already exists in database")
|
||||
return None
|
||||
|
||||
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
|
||||
mail_message = create_mail_message(db, source_item, parsed_email, folder)
|
||||
|
||||
source_item.vector_ids = vectorize_email(mail_message)
|
||||
|
||||
vectorize_email(mail_message)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Stored embedding for message {parsed_email['message_id']}")
|
||||
logger.info("Vector IDs:")
|
||||
for vector_id in source_item.vector_ids:
|
||||
logger.info(f" - {vector_id}")
|
||||
|
||||
return source_item.id
|
||||
|
||||
logger.info(f"Stored embedding for message {mail_message.message_id}")
|
||||
logger.info("Chunks:")
|
||||
for chunk in mail_message.chunks:
|
||||
logger.info(f" - {chunk.id}")
|
||||
|
||||
return mail_message.id
|
||||
|
||||
|
||||
@app.task(name=SYNC_ACCOUNT)
|
||||
def sync_account(account_id: int) -> dict:
|
||||
"""
|
||||
Synchronize emails from a specific account.
|
||||
|
||||
|
||||
Args:
|
||||
account_id: ID of the EmailAccount to sync
|
||||
|
||||
|
||||
Returns:
|
||||
dict with stats about the sync operation
|
||||
"""
|
||||
logger.info(f"Syncing account {account_id}")
|
||||
|
||||
with make_session() as db:
|
||||
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
|
||||
if not account or not account.active:
|
||||
logger.warning(f"Account {account_id} not found or inactive")
|
||||
return {"error": "Account not found or inactive"}
|
||||
|
||||
|
||||
folders_to_process = account.folders or ["INBOX"]
|
||||
since_date = account.last_sync_at or datetime(1970, 1, 1)
|
||||
|
||||
messages_found = 0
|
||||
new_messages = 0
|
||||
errors = 0
|
||||
|
||||
|
||||
def process_message_wrapper(
|
||||
account_id: int, message_id: str, folder: str, raw_email: str
|
||||
) -> int | None:
|
||||
if check_message_exists(db, account_id, message_id, raw_email):
|
||||
return None
|
||||
return process_message.delay(account_id, message_id, folder, raw_email)
|
||||
|
||||
try:
|
||||
with imap_connection(account) as conn:
|
||||
for folder in folders_to_process:
|
||||
folder_stats = process_folder(conn, folder, account, since_date, process_message.delay)
|
||||
|
||||
folder_stats = process_folder(
|
||||
conn, folder, account, since_date, process_message_wrapper
|
||||
)
|
||||
|
||||
messages_found += folder_stats["messages_found"]
|
||||
new_messages += folder_stats["new_messages"]
|
||||
errors += folder_stats["errors"]
|
||||
|
||||
|
||||
account.last_sync_at = datetime.now()
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
return {
|
||||
"account": account.email_address,
|
||||
"folders_processed": len(folders_to_process),
|
||||
"messages_found": messages_found,
|
||||
"new_messages": new_messages,
|
||||
"errors": errors
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
@ -134,18 +132,18 @@ def sync_account(account_id: int) -> dict:
|
||||
def sync_all_accounts() -> list[dict]:
|
||||
"""
|
||||
Synchronize all active email accounts.
|
||||
|
||||
|
||||
Returns:
|
||||
List of task IDs that were scheduled
|
||||
"""
|
||||
with make_session() as db:
|
||||
active_accounts = db.query(EmailAccount).filter(EmailAccount.active).all()
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"account_id": account.id,
|
||||
"email": account.email_address,
|
||||
"task_id": sync_account.delay(account.id).id
|
||||
"task_id": sync_account.delay(account.id).id,
|
||||
}
|
||||
for account in active_accounts
|
||||
]
|
||||
]
|
||||
|
@ -1,5 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
@app.task(name="kb.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
@ -1,5 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
@app.task(name="kb.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
@ -1,5 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
@app.task(name="kb.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
@ -1,6 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
|
||||
@app.task(name="kb.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
@ -1,5 +0,0 @@
|
||||
from memory.workers.celery_app import app
|
||||
|
||||
@app.task(name="memory.text.ping")
|
||||
def ping():
|
||||
return "pong"
|
264
tests/memory/common/parsers/test_email_parsers.py
Normal file
@ -0,0 +1,264 @@
|
||||
import email
|
||||
import email.mime.multipart
|
||||
import email.mime.text
|
||||
import email.mime.base
|
||||
|
||||
from datetime import datetime
|
||||
from email.utils import formatdate
|
||||
from unittest.mock import ANY, patch
|
||||
import pytest
|
||||
from memory.common.parsers.email import (
|
||||
compute_message_hash,
|
||||
extract_attachments,
|
||||
extract_body,
|
||||
extract_date,
|
||||
extract_recipients,
|
||||
parse_email_message,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Use a simple counter to generate unique message IDs without calling make_msgid
|
||||
_msg_id_counter = 0
|
||||
def _generate_test_message_id():
|
||||
"""Generate a simple message ID for testing without expensive calls"""
|
||||
global _msg_id_counter
|
||||
_msg_id_counter += 1
|
||||
return f"<test-message-{_msg_id_counter}@example.com>"
|
||||
|
||||
|
||||
def create_email_message(
|
||||
subject="Test Subject",
|
||||
from_addr="sender@example.com",
|
||||
to_addrs="recipient@example.com",
|
||||
cc_addrs=None,
|
||||
bcc_addrs=None,
|
||||
date=None,
|
||||
body="Test body content",
|
||||
attachments=None,
|
||||
multipart=True,
|
||||
message_id=None,
|
||||
):
|
||||
"""Helper function to create email.message.Message objects for testing"""
|
||||
if multipart:
|
||||
msg = email.mime.multipart.MIMEMultipart()
|
||||
msg.attach(email.mime.text.MIMEText(body))
|
||||
|
||||
if attachments:
|
||||
for attachment in attachments:
|
||||
attachment_part = email.mime.base.MIMEBase(
|
||||
"application", "octet-stream"
|
||||
)
|
||||
attachment_part.set_payload(attachment["content"])
|
||||
attachment_part.add_header(
|
||||
"Content-Disposition",
|
||||
f"attachment; filename={attachment['filename']}",
|
||||
)
|
||||
msg.attach(attachment_part)
|
||||
else:
|
||||
msg = email.mime.text.MIMEText(body)
|
||||
|
||||
msg["Subject"] = subject
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = to_addrs
|
||||
|
||||
if cc_addrs:
|
||||
msg["Cc"] = cc_addrs
|
||||
if bcc_addrs:
|
||||
msg["Bcc"] = bcc_addrs
|
||||
if date:
|
||||
msg["Date"] = formatdate(float(date.timestamp()))
|
||||
if message_id:
|
||||
msg["Message-ID"] = message_id
|
||||
else:
|
||||
msg["Message-ID"] = _generate_test_message_id()
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"to_addr, cc_addr, bcc_addr, expected",
|
||||
[
|
||||
# Single recipient in To field
|
||||
("recipient@example.com", None, None, ["recipient@example.com"]),
|
||||
# Multiple recipients in To field
|
||||
(
|
||||
"recipient1@example.com, recipient2@example.com",
|
||||
None,
|
||||
None,
|
||||
["recipient1@example.com", "recipient2@example.com"],
|
||||
),
|
||||
# To, Cc fields
|
||||
(
|
||||
"recipient@example.com",
|
||||
"cc@example.com",
|
||||
None,
|
||||
["recipient@example.com", "cc@example.com"],
|
||||
),
|
||||
# To, Cc, Bcc fields
|
||||
(
|
||||
"recipient@example.com",
|
||||
"cc@example.com",
|
||||
"bcc@example.com",
|
||||
["recipient@example.com", "cc@example.com", "bcc@example.com"],
|
||||
),
|
||||
# Empty fields
|
||||
("", "", "", []),
|
||||
],
|
||||
)
|
||||
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
|
||||
msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
|
||||
assert sorted(extract_recipients(msg)) == sorted(expected)
|
||||
|
||||
|
||||
def test_extract_date_missing():
|
||||
msg = create_email_message(date=None)
|
||||
assert extract_date(msg) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"date_str",
|
||||
[
|
||||
"Invalid Date Format",
|
||||
"2023-01-01", # ISO format but not RFC compliant
|
||||
"Monday, Jan 1, 2023", # Descriptive but not RFC compliant
|
||||
"01/01/2023", # Common format but not RFC compliant
|
||||
"", # Empty string
|
||||
],
|
||||
)
|
||||
def test_extract_date_invalid_formats(date_str):
|
||||
msg = create_email_message()
|
||||
msg["Date"] = date_str
|
||||
assert extract_date(msg) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"date_str",
|
||||
[
|
||||
"Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format
|
||||
"01 Jan 2023 12:00:00 +0000", # RFC 822 format
|
||||
"Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name
|
||||
],
|
||||
)
|
||||
def test_extract_date(date_str):
|
||||
msg = create_email_message()
|
||||
msg["Date"] = date_str
|
||||
result = extract_date(msg)
|
||||
|
||||
assert result is not None
|
||||
assert result.year == 2023
|
||||
assert result.month == 1
|
||||
assert result.day == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multipart", [True, False])
|
||||
def test_extract_body_text_plain(multipart):
|
||||
body_content = "This is a test email body"
|
||||
msg = create_email_message(body=body_content, multipart=multipart)
|
||||
extracted = extract_body(msg)
|
||||
|
||||
# Strip newlines for comparison since multipart emails often add them
|
||||
assert extracted.strip() == body_content.strip()
|
||||
|
||||
|
||||
def test_extract_body_with_attachments():
|
||||
body_content = "This is a test email body"
|
||||
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
|
||||
msg = create_email_message(body=body_content, attachments=attachments)
|
||||
assert body_content in extract_body(msg)
|
||||
|
||||
|
||||
def test_extract_attachments_none():
|
||||
msg = create_email_message(multipart=True)
|
||||
assert extract_attachments(msg) == []
|
||||
|
||||
|
||||
def test_extract_attachments_with_files():
|
||||
attachments = [
|
||||
{"filename": "test1.txt", "content": b"content1"},
|
||||
{"filename": "test2.pdf", "content": b"content2"},
|
||||
]
|
||||
msg = create_email_message(attachments=attachments)
|
||||
|
||||
result = extract_attachments(msg)
|
||||
assert len(result) == 2
|
||||
assert result[0]["filename"] == "test1.txt"
|
||||
assert result[1]["filename"] == "test2.pdf"
|
||||
|
||||
|
||||
def test_extract_attachments_non_multipart():
|
||||
msg = create_email_message(multipart=False)
|
||||
assert extract_attachments(msg) == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msg_id, subject, sender, body, expected",
|
||||
[
|
||||
(
|
||||
"<test@example.com>",
|
||||
"Test Subject",
|
||||
"sender@example.com",
|
||||
"Test body",
|
||||
b"\xf2\xbd", # First two bytes of the actual hash
|
||||
),
|
||||
(
|
||||
"<different@example.com>",
|
||||
"Test Subject",
|
||||
"sender@example.com",
|
||||
"Test body",
|
||||
b"\xa4\x15", # Will be different from the first hash
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compute_message_hash(msg_id, subject, sender, body, expected):
|
||||
result = compute_message_hash(msg_id, subject, sender, body)
|
||||
|
||||
# Verify it's bytes and correct length for SHA-256 (32 bytes)
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) == 32
|
||||
|
||||
# Verify first two bytes match expected
|
||||
assert result[:2] == expected
|
||||
|
||||
|
||||
def test_hash_consistency():
|
||||
args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
|
||||
assert compute_message_hash(*args) == compute_message_hash(*args)
|
||||
|
||||
|
||||
def test_parse_simple_email():
|
||||
test_date = datetime(2023, 1, 1, 12, 0, 0)
|
||||
msg_id = "<test123@example.com>"
|
||||
msg = create_email_message(
|
||||
subject="Test Subject",
|
||||
from_addr="sender@example.com",
|
||||
to_addrs="recipient@example.com",
|
||||
date=test_date,
|
||||
body="Test body content",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
result = parse_email_message(msg.as_string(), msg_id)
|
||||
|
||||
assert result == {
|
||||
"message_id": msg_id,
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": ["recipient@example.com"],
|
||||
"body": "Test body content\n",
|
||||
"attachments": [],
|
||||
"sent_at": ANY,
|
||||
"hash": b'\xed\xa0\x9b\xd4\t4\x06\xb9l\xa4\xb3*\xe4NpZ\x19\xc2\x9b\x87'
|
||||
+ b'\xa6\x12\r\x7fS\xb6\xf1\xbe\x95\x9c\x99\xf1',
|
||||
}
|
||||
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
|
||||
|
||||
|
||||
def test_parse_email_with_attachments():
|
||||
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
|
||||
msg = create_email_message(attachments=attachments)
|
||||
|
||||
result = parse_email_message(msg.as_string(), "123")
|
||||
|
||||
assert len(result["attachments"]) == 1
|
||||
assert result["attachments"][0]["filename"] == "test.txt"
|
285
tests/memory/common/test_chunker.py
Normal file
@ -0,0 +1,285 @@
|
||||
import pytest
|
||||
from memory.common.chunker import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []),
|
||||
("hello", ["hello"]),
|
||||
("This is a simple sentence", ["This is a simple sentence"]),
|
||||
("word1 word2", ["word1 word2"]),
|
||||
(" ", []), # Just spaces
|
||||
("\n\t ", []), # Whitespace characters
|
||||
("word1 \n word2\t word3", ["word1 word2 word3"]), # Mixed whitespace
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_word_chunks with various inputs"""
|
||||
assert list(yield_word_chunks(text)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
(
|
||||
"word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
|
||||
['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
|
||||
),
|
||||
(
|
||||
"supercalifragilisticexpialidocious",
|
||||
["supercalifragilisticexpialidocious"],
|
||||
)
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_long_text(text, expected):
|
||||
"""Test chunking with long text that exceeds token limits"""
|
||||
assert list(yield_word_chunks(text, max_tokens=10)) == expected
|
||||
|
||||
|
||||
def test_yield_word_chunk_single_long_word():
|
||||
"""Test behavior with a single word longer than the token limit"""
|
||||
max_tokens = 5 # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2) # Word twice as long as max
|
||||
|
||||
chunks = list(yield_word_chunks(long_word, max_tokens))
|
||||
# With our changes, this should be a single chunk
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == long_word
|
||||
|
||||
|
||||
def test_yield_word_chunk_small_token_limit():
|
||||
"""Test with a very small max_tokens value to force chunking"""
|
||||
text = "one two three four five"
|
||||
max_tokens = 1 # Very small to force chunking after each word
|
||||
|
||||
assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, max_tokens, expected_chunks",
|
||||
[
|
||||
# Empty text
|
||||
("", 10, []),
|
||||
# Text below token limit
|
||||
("hello world", 10, ["hello world"]),
|
||||
# Text right at token limit
|
||||
(
|
||||
"word1 word2", # 11 chars with space
|
||||
3, # 12 chars limit
|
||||
["word1 word2"]
|
||||
),
|
||||
# Text just over token limit should split
|
||||
(
|
||||
"word1 word2 word3", # 17 chars with spaces
|
||||
4, # 16 chars limit
|
||||
["word1 word2 word3"]
|
||||
),
|
||||
# Each word exactly at token limit
|
||||
(
|
||||
"aaaa bbbb cccc", # Each word is exactly 4 chars (1 token)
|
||||
1, # 1 token limit (4 chars)
|
||||
["aaaa", "bbbb", "cccc"]
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
|
||||
"""Test different combinations of text and token limits"""
|
||||
assert list(yield_word_chunks(text, max_tokens)) == expected_chunks
|
||||
|
||||
|
||||
def test_yield_word_chunk_real_world_example():
|
||||
"""Test with a realistic text example"""
|
||||
text = (
|
||||
"The yield_word_chunks function splits text into chunks based on word boundaries. "
|
||||
"It tries to maximize chunk size while staying under the specified token limit. "
|
||||
"This behavior is essential for processing large documents efficiently."
|
||||
)
|
||||
|
||||
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
|
||||
assert list(yield_word_chunks(text, max_tokens)) == [
|
||||
'The yield_word_chunks function splits text',
|
||||
'into chunks based on word boundaries. It',
|
||||
'tries to maximize chunk size while staying',
|
||||
'under the specified token limit. This',
|
||||
'behavior is essential for processing large',
|
||||
'documents efficiently.',
|
||||
]
|
||||
|
||||
|
||||
# Tests for yield_spans function
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []), # Empty text should yield nothing
|
||||
("Simple paragraph", ["Simple paragraph"]), # Single paragraph under token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
)
|
||||
def test_yield_spans_basic_behavior(text, expected):
|
||||
"""Test basic behavior of yield_spans with various inputs"""
|
||||
assert list(yield_spans(text)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_paragraphs():
|
||||
"""Test splitting by paragraphs"""
|
||||
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
|
||||
expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
|
||||
assert list(yield_spans(text)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_sentences():
|
||||
"""Test splitting by sentences when paragraphs exceed token limit"""
|
||||
# Create a paragraph that exceeds token limit but sentences are within limit
|
||||
max_tokens = 5 # 20 chars with CHARS_PER_TOKEN = 4
|
||||
sentence1 = "Short sentence one." # ~20 chars
|
||||
sentence2 = "Another short sentence." # ~24 chars
|
||||
text = f"{sentence1} {sentence2}" # Combined exceeds 5 tokens
|
||||
|
||||
# Function should now preserve punctuation
|
||||
expected = ["Short sentence one.", "Another short sentence."]
|
||||
assert list(yield_spans(text, max_tokens)) == expected
|
||||
|
||||
|
||||
def test_yield_spans_words():
|
||||
"""Test splitting by words when sentences exceed token limit"""
|
||||
max_tokens = 3 # 12 chars with CHARS_PER_TOKEN = 4
|
||||
long_sentence = "This sentence has several words and needs word-level chunking."
|
||||
|
||||
assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']
|
||||
|
||||
|
||||
def test_yield_spans_complex_document():
|
||||
"""Test with a document containing multiple paragraphs and sentences"""
|
||||
max_tokens = 10 # 40 chars with CHARS_PER_TOKEN = 4
|
||||
text = (
|
||||
"Paragraph one with a short sentence. And another sentence that should be split.\n\n"
|
||||
"Paragraph two is also here. It has multiple sentences. Some are short. "
|
||||
"This one is longer and might need word splitting depending on the limit.\n\n"
|
||||
"Final short paragraph."
|
||||
)
|
||||
|
||||
assert list(yield_spans(text, max_tokens)) == [
|
||||
"Paragraph one with a short sentence.",
|
||||
"And another sentence that should be split.",
|
||||
"Paragraph two is also here.",
|
||||
"It has multiple sentences.",
|
||||
"Some are short.",
|
||||
"This one is longer and might need word",
|
||||
"splitting depending on the limit.",
|
||||
"Final short paragraph."
|
||||
]
|
||||
|
||||
|
||||
def test_yield_spans_very_long_word():
|
||||
"""Test with a word that exceeds the token limit"""
|
||||
max_tokens = 2 # 8 chars with CHARS_PER_TOKEN = 4
|
||||
long_word = "supercalifragilisticexpialidocious" # Much longer than 8 chars
|
||||
|
||||
assert list(yield_spans(long_word, max_tokens)) == [long_word]
|
||||
|
||||
|
||||
def test_yield_spans_with_punctuation():
|
||||
"""Test sentence splitting with various punctuation"""
|
||||
text = "First sentence! Second sentence? Third sentence."
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]
|
||||
|
||||
|
||||
def test_yield_spans_edge_cases():
|
||||
"""Test edge cases like empty paragraphs, single character paragraphs"""
|
||||
text = "\n\nA\n\n\n\nB\n\n"
|
||||
|
||||
assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("", []), # Empty text
|
||||
("Short text", ["Short text"]), # Text below token limit
|
||||
(" ", []), # Just whitespace
|
||||
]
|
||||
)
|
||||
def test_chunk_text_basic_behavior(text, expected):
|
||||
"""Test basic behavior of chunk_text with various inputs"""
|
||||
assert list(chunk_text(text)) == expected
|
||||
|
||||
|
||||
def test_chunk_text_single_paragraph():
|
||||
"""Test chunking a single paragraph that fits within token limit"""
|
||||
text = "This is a simple paragraph that should fit in one chunk."
|
||||
assert list(chunk_text(text, max_tokens=20)) == [text]
|
||||
|
||||
|
||||
def test_chunk_text_multi_paragraph():
|
||||
"""Test chunking multiple paragraphs"""
|
||||
text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
|
||||
assert list(chunk_text(text, max_tokens=20)) == [text]
|
||||
|
||||
|
||||
def test_chunk_text_long_text():
|
||||
"""Test chunking with long text that exceeds token limit"""
|
||||
# Create a long text that will need multiple chunks
|
||||
sentences = [f"This is sentence {i:02}." for i in range(50)]
|
||||
text = " ".join(sentences)
|
||||
|
||||
max_tokens = 10 # 10 tokens = ~40 chars
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
|
||||
f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
|
||||
] + [
|
||||
'This is sentence 49.'
|
||||
]
|
||||
|
||||
|
||||
def test_chunk_text_with_overlap():
|
||||
"""Test chunking with overlap between chunks"""
|
||||
# Create text with distinct parts to test overlap
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']
|
||||
|
||||
|
||||
def test_chunk_text_zero_overlap():
|
||||
"""Test chunking with zero overlap"""
|
||||
text = "Part A. Part B. Part C. Part D. Part E."
|
||||
|
||||
# 2 tokens = ~8 chars
|
||||
assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']
|
||||
|
||||
|
||||
def test_chunk_text_clean_break():
|
||||
"""Test that chunking attempts to break at sentence boundaries"""
|
||||
text = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||
|
||||
max_tokens = 5 # Enough for about 2 sentences
|
||||
assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']
|
||||
|
||||
|
||||
def test_chunk_text_very_long_sentences():
|
||||
"""Test with very long sentences that exceed the token limit"""
|
||||
text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."
|
||||
|
||||
max_tokens = 5 # Small limit to force splitting
|
||||
assert list(chunk_text(text, max_tokens=max_tokens)) == [
|
||||
'This is a very long sentence with many many',
|
||||
'words that will definitely exceed the',
|
||||
'token limit we set for',
|
||||
'this particular test',
|
||||
'case and should be split into multiple',
|
||||
'chunks by the function.',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"string, expected_count",
|
||||
[
|
||||
("", 0),
|
||||
("a" * CHARS_PER_TOKEN, 1),
|
||||
("a" * (CHARS_PER_TOKEN * 2), 2),
|
||||
("a" * (CHARS_PER_TOKEN * 2 + 1), 2), # Truncation
|
||||
("a" * (CHARS_PER_TOKEN - 1), 0), # Truncation
|
||||
]
|
||||
)
|
||||
def test_approx_token_count(string, expected_count):
|
||||
assert approx_token_count(string) == expected_count
|
@ -1,321 +1,51 @@
import uuid
import pytest
from unittest.mock import Mock, patch
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, approx_token_count, get_modality, embed_text, embed_file, embed_mixed, embed_page, embed
from PIL import Image
import pathlib
from memory.common import settings
from memory.common.embedding import (
    get_modality,
    embed_text,
    embed_file,
    embed_mixed,
    embed_page,
    embed,
    write_to_file,
    make_chunk,
)


@pytest.fixture
def mock_embed(mock_voyage_client):
    vectors = ([i] for i in range(1000))

    def embed(texts, model):
    def embed(texts, model, input_type):
        return Mock(embeddings=[next(vectors) for _ in texts])

    mock_voyage_client.embed.side_effect = embed
    mock_voyage_client.multimodal_embed.side_effect = embed

    return mock_voyage_client


@pytest.mark.parametrize(
    "text, expected",
    [
        ("", []),
        ("hello", ["hello"]),
        ("This is a simple sentence", ["This is a simple sentence"]),
        ("word1 word2", ["word1 word2"]),
        (" ", []),  # Just spaces
        ("\n\t ", []),  # Whitespace characters
        ("word1 \n word2\t word3", ["word1 word2 word3"]),  # Mixed whitespace
    ]
)
def test_yield_word_chunk_basic_behavior(text, expected):
    """Test basic behavior of yield_word_chunks with various inputs"""
    assert list(yield_word_chunks(text)) == expected


@pytest.mark.parametrize(
    "text, expected",
    [
        (
            "word1 word2 word3 word4 verylongwordthatexceedsthelimit word5",
            ['word1 word2 word3 word4', 'verylongwordthatexceedsthelimit word5'],
        ),
        (
            "supercalifragilisticexpialidocious",
            ["supercalifragilisticexpialidocious"],
        )
    ]
)
def test_yield_word_chunk_long_text(text, expected):
    """Test chunking with long text that exceeds token limits"""
    assert list(yield_word_chunks(text, max_tokens=10)) == expected


def test_yield_word_chunk_single_long_word():
    """Test behavior with a single word longer than the token limit"""
    max_tokens = 5  # 5 tokens = 20 chars with CHARS_PER_TOKEN = 4
    long_word = "x" * (max_tokens * CHARS_PER_TOKEN * 2)  # Word twice as long as max

    chunks = list(yield_word_chunks(long_word, max_tokens))
    # With our changes, this should be a single chunk
    assert len(chunks) == 1
    assert chunks[0] == long_word


def test_yield_word_chunk_small_token_limit():
    """Test with a very small max_tokens value to force chunking"""
    text = "one two three four five"
    max_tokens = 1  # Very small to force chunking after each word

    assert list(yield_word_chunks(text, max_tokens)) == ["one two", "three", "four", "five"]
@pytest.mark.parametrize(
    "text, max_tokens, expected_chunks",
    [
        # Empty text
        ("", 10, []),
        # Text below token limit
        ("hello world", 10, ["hello world"]),
        # Text right at token limit
        (
            "word1 word2",  # 11 chars with space
            3,  # 12 chars limit
            ["word1 word2"]
        ),
        # Text just over token limit should split
        (
            "word1 word2 word3",  # 17 chars with spaces
            4,  # 16 chars limit
            ["word1 word2 word3"]
        ),
        # Each word exactly at token limit
        (
            "aaaa bbbb cccc",  # Each word is exactly 4 chars (1 token)
            1,  # 1 token limit (4 chars)
            ["aaaa", "bbbb", "cccc"]
        ),
    ]
)
def test_yield_word_chunk_various_token_limits(text, max_tokens, expected_chunks):
    """Test different combinations of text and token limits"""
    assert list(yield_word_chunks(text, max_tokens)) == expected_chunks


def test_yield_word_chunk_real_world_example():
    """Test with a realistic text example"""
    text = (
        "The yield_word_chunks function splits text into chunks based on word boundaries. "
        "It tries to maximize chunk size while staying under the specified token limit. "
        "This behavior is essential for processing large documents efficiently."
    )

    max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
    assert list(yield_word_chunks(text, max_tokens)) == [
        'The yield_word_chunks function splits text',
        'into chunks based on word boundaries. It',
        'tries to maximize chunk size while staying',
        'under the specified token limit. This',
        'behavior is essential for processing large',
        'documents efficiently.',
    ]


# Tests for yield_spans function
@pytest.mark.parametrize(
    "text, expected",
    [
        ("", []),  # Empty text should yield nothing
        ("Simple paragraph", ["Simple paragraph"]),  # Single paragraph under token limit
        (" ", []),  # Just whitespace
    ]
)
def test_yield_spans_basic_behavior(text, expected):
    """Test basic behavior of yield_spans with various inputs"""
    assert list(yield_spans(text)) == expected


def test_yield_spans_paragraphs():
    """Test splitting by paragraphs"""
    text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
    expected = ["Paragraph one.", "Paragraph two.", "Paragraph three."]
    assert list(yield_spans(text)) == expected


def test_yield_spans_sentences():
    """Test splitting by sentences when paragraphs exceed token limit"""
    # Create a paragraph that exceeds token limit but sentences are within limit
    max_tokens = 5  # 20 chars with CHARS_PER_TOKEN = 4
    sentence1 = "Short sentence one."  # ~20 chars
    sentence2 = "Another short sentence."  # ~24 chars
    text = f"{sentence1} {sentence2}"  # Combined exceeds 5 tokens

    # Function should now preserve punctuation
    expected = ["Short sentence one.", "Another short sentence."]
    assert list(yield_spans(text, max_tokens)) == expected


def test_yield_spans_words():
    """Test splitting by words when sentences exceed token limit"""
    max_tokens = 3  # 12 chars with CHARS_PER_TOKEN = 4
    long_sentence = "This sentence has several words and needs word-level chunking."

    assert list(yield_spans(long_sentence, max_tokens)) == ['This sentence', 'has several', 'words and needs', 'word-level', 'chunking.']


def test_yield_spans_complex_document():
    """Test with a document containing multiple paragraphs and sentences"""
    max_tokens = 10  # 40 chars with CHARS_PER_TOKEN = 4
    text = (
        "Paragraph one with a short sentence. And another sentence that should be split.\n\n"
        "Paragraph two is also here. It has multiple sentences. Some are short. "
        "This one is longer and might need word splitting depending on the limit.\n\n"
        "Final short paragraph."
    )

    assert list(yield_spans(text, max_tokens)) == [
        "Paragraph one with a short sentence.",
        "And another sentence that should be split.",
        "Paragraph two is also here.",
        "It has multiple sentences.",
        "Some are short.",
        "This one is longer and might need word",
        "splitting depending on the limit.",
        "Final short paragraph."
    ]
def test_yield_spans_very_long_word():
    """Test with a word that exceeds the token limit"""
    max_tokens = 2  # 8 chars with CHARS_PER_TOKEN = 4
    long_word = "supercalifragilisticexpialidocious"  # Much longer than 8 chars

    assert list(yield_spans(long_word, max_tokens)) == [long_word]


def test_yield_spans_with_punctuation():
    """Test sentence splitting with various punctuation"""
    text = "First sentence! Second sentence? Third sentence."

    assert list(yield_spans(text, max_tokens=10)) == ["First sentence!", "Second sentence?", "Third sentence."]


def test_yield_spans_edge_cases():
    """Test edge cases like empty paragraphs, single character paragraphs"""
    text = "\n\nA\n\n\n\nB\n\n"

    assert list(yield_spans(text, max_tokens=10)) == ["A", "B"]


@pytest.mark.parametrize(
    "text, expected",
    [
        ("", []),  # Empty text
        ("Short text", ["Short text"]),  # Text below token limit
        (" ", []),  # Just whitespace
    ]
)
def test_chunk_text_basic_behavior(text, expected):
    """Test basic behavior of chunk_text with various inputs"""
    assert list(chunk_text(text)) == expected


def test_chunk_text_single_paragraph():
    """Test chunking a single paragraph that fits within token limit"""
    text = "This is a simple paragraph that should fit in one chunk."
    assert list(chunk_text(text, max_tokens=20)) == [text]


def test_chunk_text_multi_paragraph():
    """Test chunking multiple paragraphs"""
    text = "Paragraph one.\n\nParagraph two.\n\nParagraph three."
    assert list(chunk_text(text, max_tokens=20)) == [text]


def test_chunk_text_long_text():
    """Test chunking with long text that exceeds token limit"""
    # Create a long text that will need multiple chunks
    sentences = [f"This is sentence {i:02}." for i in range(50)]
    text = " ".join(sentences)

    max_tokens = 10  # 10 tokens = ~40 chars
    assert list(chunk_text(text, max_tokens=max_tokens, overlap=6)) == [
        f'This is sentence {i:02}. This is sentence {i + 1:02}.' for i in range(49)
    ] + [
        'This is sentence 49.'
    ]


def test_chunk_text_with_overlap():
    """Test chunking with overlap between chunks"""
    # Create text with distinct parts to test overlap
    text = "Part A. Part B. Part C. Part D. Part E."

    assert list(chunk_text(text, max_tokens=4, overlap=3)) == ['Part A. Part B. Part C.', 'Part C. Part D. Part E.', 'Part E.']


def test_chunk_text_zero_overlap():
    """Test chunking with zero overlap"""
    text = "Part A. Part B. Part C. Part D. Part E."

    # 2 tokens = ~8 chars
    assert list(chunk_text(text, max_tokens=2, overlap=0)) == ['Part A. Part B.', 'Part C. Part D.', 'Part E.']


def test_chunk_text_clean_break():
    """Test that chunking attempts to break at sentence boundaries"""
    text = "First sentence. Second sentence. Third sentence. Fourth sentence."

    max_tokens = 5  # Enough for about 2 sentences
    assert list(chunk_text(text, max_tokens=max_tokens, overlap=3)) == ['First sentence. Second sentence.', 'Third sentence. Fourth sentence.']


def test_chunk_text_very_long_sentences():
    """Test with very long sentences that exceed the token limit"""
    text = "This is a very long sentence with many many words that will definitely exceed the token limit we set for this particular test case and should be split into multiple chunks by the function."

    max_tokens = 5  # Small limit to force splitting
    assert list(chunk_text(text, max_tokens=max_tokens)) == [
        'This is a very long sentence with many many',
        'words that will definitely exceed the',
        'token limit we set for',
        'this particular test',
        'case and should be split into multiple',
        'chunks by the function.',
    ]


@pytest.mark.parametrize(
    "string, expected_count",
    [
        ("", 0),
        ("a" * CHARS_PER_TOKEN, 1),
        ("a" * (CHARS_PER_TOKEN * 2), 2),
        ("a" * (CHARS_PER_TOKEN * 2 + 1), 2),  # Truncation
        ("a" * (CHARS_PER_TOKEN - 1), 0),  # Truncation
    ]
)
def test_approx_token_count(string, expected_count):
    assert approx_token_count(string) == expected_count
@pytest.mark.parametrize(
    "mime_type, expected_modality",
    [
        ("text/plain", "doc"),
        ("text/html", "doc"),
        ("text/plain", "text"),
        ("text/html", "blog"),
        ("image/jpeg", "photo"),
        ("image/png", "photo"),
        ("application/pdf", "book"),
        ("application/pdf", "doc"),
        ("application/epub+zip", "book"),
        ("application/mobi", "book"),
        ("application/x-mobipocket-ebook", "book"),
        ("audio/mp3", "unknown"),
        ("video/mp4", "unknown"),
        ("text/something-new", "doc"),  # Should match by 'text/' stem
        ("text/something-new", "text"),  # Should match by 'text/' stem
        ("image/something-new", "photo"),  # Should match by 'image/' stem
        ("custom/format", "unknown"),  # No matching stem
    ]
    ],
)
def test_get_modality(mime_type, expected_modality):
    assert get_modality(mime_type) == expected_modality
@ -329,13 +59,13 @@ def test_embed_text(mock_embed):
def test_embed_file(mock_embed, tmp_path):
    mock_file = tmp_path / "test.txt"
    mock_file.write_text("file content")

    assert embed_file(mock_file) == [[0]]


def test_embed_mixed(mock_embed):
    items = ["text", {"type": "image", "data": "base64"}]
    assert embed_mixed(items) == [[0], [1]]
    assert embed_mixed(items) == [[0]]


def test_embed_page_text_only(mock_embed):
@ -345,16 +75,159 @@ def test_embed_page_text_only(mock_embed):

def test_embed_page_mixed_content(mock_embed):
    page = {"contents": ["text", {"type": "image", "data": "base64"}]}
    assert embed_page(page) == [[0], [1]]
    assert embed_page(page) == [[0]]


def test_embed(mock_embed):
    mime_type = "text/plain"
    content = "sample content"
    metadata = {"source": "test"}

    with patch.object(uuid, "uuid4", return_value="id1"):
        modality, vectors = embed(mime_type, content, metadata)

    assert modality == "doc"
    assert vectors == [('id1', [0], {'source': 'test'})]
        modality, chunks = embed(mime_type, content, metadata)

    assert modality == "text"
    assert [
        {
            "id": c.id,
            "file_path": c.file_path,
            "content": c.content,
            "embedding_model": c.embedding_model,
            "vector": c.vector,
            "item_metadata": c.item_metadata,
        }
        for c in chunks
    ] == [
        {
            "content": "sample content",
            "embedding_model": "voyage-3-large",
            "file_path": None,
            "id": "id1",
            "item_metadata": {"source": "test"},
            "vector": [0],
        },
    ]


def test_write_to_file_text(mock_file_storage):
    """Test writing a string to a file."""
    chunk_id = "test-chunk-id"
    content = "This is a test string"

    file_path = write_to_file(chunk_id, content)

    assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.txt"
    assert file_path.exists()
    assert file_path.read_text() == content


def test_write_to_file_bytes(mock_file_storage):
    """Test writing bytes to a file."""
    chunk_id = "test-chunk-id"
    content = b"These are test bytes"

    file_path = write_to_file(chunk_id, content)

    assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.bin"
    assert file_path.exists()
    assert file_path.read_bytes() == content


def test_write_to_file_image(mock_file_storage):
    """Test writing an image to a file."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    chunk_id = "test-chunk-id"

    file_path = write_to_file(chunk_id, img)

    assert file_path == settings.FILE_STORAGE_DIR / f"{chunk_id}.png"
    assert file_path.exists()
    # Verify it's a valid image file by opening it
    image = Image.open(file_path)
    assert image.size == (100, 100)


def test_write_to_file_unsupported_type(mock_file_storage):
    """Test that an error is raised for unsupported content types."""
    chunk_id = "test-chunk-id"
    content = 123  # Integer is not a supported type

    with pytest.raises(ValueError, match="Unsupported content type"):
        write_to_file(chunk_id, content)
def test_make_chunk_text_only(mock_file_storage, db_session):
    """Test creating a chunk from string content."""
    page = {
        "contents": ["text content 1", "text content 2"],
        "metadata": {"source": "test"},
    }
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001")
    ):
        chunk = make_chunk(page, vector, metadata)

    assert chunk.id == "00000000-0000-0000-0000-000000000001"
    assert chunk.content == "text content 1\n\ntext content 2"
    assert chunk.file_path is None
    assert chunk.embedding_model == settings.TEXT_EMBEDDING_MODEL
    assert chunk.vector == vector
    assert chunk.item_metadata == metadata


def test_make_chunk_single_image(mock_file_storage, db_session):
    """Test creating a chunk from a single image."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    page = {"contents": [img], "metadata": {"source": "test"}}
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002")
    ):
        chunk = make_chunk(page, vector, metadata)

    assert chunk.id == "00000000-0000-0000-0000-000000000002"
    assert chunk.content is None
    assert chunk.file_path == (
        settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
    )
    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == vector
    assert chunk.item_metadata == metadata

    # Verify the file exists
    assert pathlib.Path(chunk.file_path[0]).exists()


def test_make_chunk_mixed_content(mock_file_storage, db_session):
    """Test creating a chunk from mixed content (string and image)."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    page = {"contents": ["text content", img], "metadata": {"source": "test"}}
    vector = [0.1, 0.2, 0.3]
    metadata = {"doc_type": "test", "source": "unit-test"}

    with patch.object(
        uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003")
    ):
        chunk = make_chunk(page, vector, metadata)

    assert chunk.id == "00000000-0000-0000-0000-000000000003"
    assert chunk.content is None
    assert chunk.file_path == (
        settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
    )
    assert chunk.embedding_model == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == vector
    assert chunk.item_metadata == metadata

    # Verify the files exist
    assert (
        settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_0.txt"
    ).exists()
    assert (
        settings.FILE_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
    ).exists()
@ -70,7 +70,7 @@ def test_extract_image_with_path(tmp_path):
    img.save(img_path)

    page, = extract_image(img_path)
    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
    assert page["metadata"] == {}


@ -81,7 +81,7 @@ def test_extract_image_with_bytes():
    img_bytes = buffer.getvalue()

    page, = extract_image(img_bytes)
    assert page["contents"].tobytes() == img.convert("RGB").tobytes()
    assert page["contents"][0].tobytes() == img.convert("RGB").tobytes()
    assert page["metadata"] == {}


@ -119,12 +119,11 @@ def test_extract_content_image(tmp_path):
    img_path = tmp_path / "test_img.png"
    img.save(img_path)

    result = extract_content("image/png", img_path)
    result, = extract_content("image/png", img_path)

    assert len(result) == 1
    assert isinstance(result[0]["contents"], Image.Image)
    assert result[0]["contents"].size == (100, 100)
    assert result[0]["metadata"] == {}
    assert isinstance(result["contents"][0], Image.Image)
    assert result["contents"][0].size == (100, 100)
    assert result["metadata"] == {}


def test_extract_content_unsupported_type():
@ -2,7 +2,12 @@ from unittest import mock
import pytest
from datetime import datetime, timedelta
from unittest.mock import patch
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common.db.models import (
    EmailAccount,
    MailMessage,
    SourceItem,
    EmailAttachment,
)
from memory.common import embedding
from memory.workers.tasks.email import process_message

@ -57,7 +62,7 @@ def test_email_account(db_session):
        use_ssl=True,
        folders=["INBOX", "Sent", "Archive"],
        tags=["test", "integration"],
        active=True
        active=True,
    )
    db_session.add(account)
    db_session.commit()
@ -66,62 +71,58 @@ def test_email_account(db_session):

def test_process_simple_email(db_session, test_email_account, qdrant):
    """Test processing a simple email message."""
    source_id = process_message(
    mail_message_id = process_message(
        account_id=test_email_account.id,
        message_id="101",
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )

    assert source_id is not None

    # Check that the source item was created
    source_item = db_session.query(SourceItem).filter(SourceItem.id == source_id).first()
    assert source_item is not None
    assert source_item.modality == "mail"
    assert source_item.tags == test_email_account.tags
    assert source_item.mime_type == "message/rfc822"
    assert source_item.embed_status == "RAW"

    # Check that the mail message was created and linked to the source
    mail_message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()

    mail_message = (
        db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
    )
    assert mail_message is not None
    assert mail_message.modality == "mail"
    assert mail_message.tags == test_email_account.tags
    assert mail_message.mime_type == "message/rfc822"
    assert mail_message.embed_status == "STORED"
    assert mail_message.subject == "Test Email 1"
    assert mail_message.sender == "alice@example.com"
    assert "bob@example.com" in mail_message.recipients
    assert "This is test email 1" in mail_message.body_raw
    assert "This is test email 1" in mail_message.content
    assert mail_message.folder == "INBOX"


def test_process_email_with_attachment(db_session, test_email_account, qdrant):
    """Test processing a message with an attachment."""
    source_id = process_message(
    mail_message_id = process_message(
        account_id=test_email_account.id,
        message_id="302",
        folder="Archive",
        raw_email=EMAIL_WITH_ATTACHMENT_RAW,
    )

    assert source_id is not None

    # Check mail message specifics
    mail_message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()
    mail_message = (
        db_session.query(MailMessage).filter(MailMessage.id == mail_message_id).one()
    )
    assert mail_message is not None
    assert mail_message.subject == "Email with Attachment"
    assert mail_message.sender == "eve@example.com"
    assert "This email has an attachment" in mail_message.body_raw
    assert "This email has an attachment" in mail_message.content
    assert mail_message.folder == "Archive"

    # Check attachments were processed and stored in the EmailAttachment table
    attachments = db_session.query(EmailAttachment).filter(
        EmailAttachment.mail_message_id == mail_message.id
    ).all()

    assert len(attachments) > 0
    assert attachments[0].filename == "test.txt"
    assert attachments[0].content_type == "text/plain"
    # Either content or file_path should be set
    assert attachments[0].content is not None or attachments[0].file_path is not None
    attachments = (
        db_session.query(
            EmailAttachment.filename,
            EmailAttachment.content,
            EmailAttachment.mime_type,
        )
        .filter(EmailAttachment.mail_message_id == mail_message.id)
        .all()
    )

    assert attachments == [(None, "This is a test attachment", "text/plain")]


def test_process_empty_message(db_session, test_email_account, qdrant):
@ -132,7 +133,7 @@ def test_process_empty_message(db_session, test_email_account, qdrant):
        folder="Archive",
        raw_email="",
    )

    assert source_id is None


@ -145,13 +146,13 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )

    assert source_id_1 is not None, "First call should return a source_id"

    # Count records to verify state before second call
    source_count_before = db_session.query(SourceItem).count()
    message_count_before = db_session.query(MailMessage).count()

    # Second call with same email should detect duplicate and return None
    source_id_2 = process_message(
        account_id=test_email_account.id,
@ -159,12 +160,16 @@ def test_process_duplicate_message(db_session, test_email_account, qdrant):
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )

    assert source_id_2 is None, "Second call should return None for duplicate message"

    # Verify no new records were created
    source_count_after = db_session.query(SourceItem).count()
    message_count_after = db_session.query(MailMessage).count()

    assert source_count_before == source_count_after, "No new SourceItem should be created"
    assert message_count_before == message_count_after, "No new MailMessage should be created"
    assert source_count_before == source_count_after, (
        "No new SourceItem should be created"
    )
    assert message_count_before == message_count_after, (
        "No new MailMessage should be created"
    )
@ -9,19 +9,16 @@ from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY, MagicMock, patch
import pytest
from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount
from memory.common.db.models import (
    SourceItem,
    MailMessage,
    EmailAttachment,
    EmailAccount,
)
from memory.common import settings
from memory.common import embedding
from memory.workers.email import (
    compute_message_hash,
    create_source_item,
    extract_attachments,
    extract_body,
    extract_date,
    extract_email_uid,
    extract_recipients,
    parse_email_message,
    check_message_exists,
    create_mail_message,
    fetch_email,
    fetch_email_since,
@ -35,6 +32,7 @@ from memory.workers.email import (
@pytest.fixture
def mock_uuid4():
    i = 0

    def uuid4():
        nonlocal i
        i += 1
@ -44,181 +42,6 @@ def mock_uuid4():
    yield


# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0


def _generate_test_message_id():
    """Generate a simple message ID for testing without expensive calls"""
    global _msg_id_counter
    _msg_id_counter += 1
    return f"<test-message-{_msg_id_counter}@example.com>"


def create_email_message(
    subject="Test Subject",
    from_addr="sender@example.com",
    to_addrs="recipient@example.com",
    cc_addrs=None,
    bcc_addrs=None,
    date=None,
    body="Test body content",
    attachments=None,
    multipart=True,
    message_id=None,
):
    """Helper function to create email.message.Message objects for testing"""
    if multipart:
        msg = email.mime.multipart.MIMEMultipart()
        msg.attach(email.mime.text.MIMEText(body))

        if attachments:
            for attachment in attachments:
                attachment_part = email.mime.base.MIMEBase(
                    "application", "octet-stream"
                )
                attachment_part.set_payload(attachment["content"])
                attachment_part.add_header(
                    "Content-Disposition",
                    f"attachment; filename={attachment['filename']}",
                )
                msg.attach(attachment_part)
    else:
        msg = email.mime.text.MIMEText(body)

    msg["Subject"] = subject
    msg["From"] = from_addr
    msg["To"] = to_addrs

    if cc_addrs:
        msg["Cc"] = cc_addrs
    if bcc_addrs:
        msg["Bcc"] = bcc_addrs
    if date:
        msg["Date"] = formatdate(float(date.timestamp()))
    if message_id:
        msg["Message-ID"] = message_id
    else:
        msg["Message-ID"] = _generate_test_message_id()

    return msg
@pytest.mark.parametrize(
    "to_addr, cc_addr, bcc_addr, expected",
    [
        # Single recipient in To field
        ("recipient@example.com", None, None, ["recipient@example.com"]),
        # Multiple recipients in To field
        (
            "recipient1@example.com, recipient2@example.com",
            None,
            None,
            ["recipient1@example.com", "recipient2@example.com"],
        ),
        # To, Cc fields
        (
            "recipient@example.com",
            "cc@example.com",
            None,
            ["recipient@example.com", "cc@example.com"],
        ),
        # To, Cc, Bcc fields
        (
            "recipient@example.com",
            "cc@example.com",
            "bcc@example.com",
            ["recipient@example.com", "cc@example.com", "bcc@example.com"],
        ),
        # Empty fields
        ("", "", "", []),
    ],
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
    msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
    assert sorted(extract_recipients(msg)) == sorted(expected)


def test_extract_date_missing():
    msg = create_email_message(date=None)
    assert extract_date(msg) is None


@pytest.mark.parametrize(
    "date_str",
    [
        "Invalid Date Format",
        "2023-01-01",  # ISO format but not RFC compliant
        "Monday, Jan 1, 2023",  # Descriptive but not RFC compliant
        "01/01/2023",  # Common format but not RFC compliant
        "",  # Empty string
    ],
)
def test_extract_date_invalid_formats(date_str):
    msg = create_email_message()
    msg["Date"] = date_str
    assert extract_date(msg) is None


@pytest.mark.parametrize(
    "date_str",
    [
        "Mon, 01 Jan 2023 12:00:00 +0000",  # RFC 5322 format
        "01 Jan 2023 12:00:00 +0000",  # RFC 822 format
        "Mon, 01 Jan 2023 12:00:00 GMT",  # With timezone name
    ],
)
def test_extract_date(date_str):
    msg = create_email_message()
    msg["Date"] = date_str
    result = extract_date(msg)

    assert result is not None
    assert result.year == 2023
    assert result.month == 1
    assert result.day == 1


@pytest.mark.parametrize("multipart", [True, False])
def test_extract_body_text_plain(multipart):
    body_content = "This is a test email body"
    msg = create_email_message(body=body_content, multipart=multipart)
    extracted = extract_body(msg)

    # Strip newlines for comparison since multipart emails often add them
    assert extracted.strip() == body_content.strip()


def test_extract_body_with_attachments():
    body_content = "This is a test email body"
    attachments = [{"filename": "test.txt", "content": b"attachment content"}]
    msg = create_email_message(body=body_content, attachments=attachments)
    assert body_content in extract_body(msg)


def test_extract_attachments_none():
    msg = create_email_message(multipart=True)
    assert extract_attachments(msg) == []


def test_extract_attachments_with_files():
    attachments = [
        {"filename": "test1.txt", "content": b"content1"},
        {"filename": "test2.pdf", "content": b"content2"},
    ]
    msg = create_email_message(attachments=attachments)

    result = extract_attachments(msg)
    assert len(result) == 2
    assert result[0]["filename"] == "test1.txt"
    assert result[1]["filename"] == "test2.pdf"


def test_extract_attachments_non_multipart():
    msg = create_email_message(multipart=False)
    assert extract_attachments(msg) == []
@pytest.mark.parametrize(
    "attachment_size, max_inline_size, message_id",
    [
@ -237,7 +60,6 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
    }
    message = MailMessage(
        id=1,
        source=SourceItem(tags=["test"]),
        message_id=message_id,
        sender="sender@example.com",
        folder="INBOX",
@ -247,12 +69,8 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
    result = process_attachment(attachment, message)

    assert result is not None
    # For inline attachments, content should be base64 encoded string
    assert isinstance(result.content, bytes)
    # Decode the base64 string and compare with the original content
    decoded_content = base64.b64decode(result.content)
    assert decoded_content == attachment["content"]
    assert result.file_path is None
    assert result.content == attachment["content"].decode("utf-8", errors="replace")
    assert result.filename is None


@pytest.mark.parametrize(
@ -266,14 +84,13 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
)
def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
    attachment = {
        "filename": "test.txt",
        "filename": "test/with:special\\chars.txt",
        "content_type": "text/plain",
        "size": attachment_size,
        "content": b"a" * attachment_size,
    }
    message = MailMessage(
        id=1,
        source=SourceItem(tags=["test"]),
        message_id=message_id,
        sender="sender@example.com",
        folder="INBOX",
@ -283,7 +100,12 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):

    assert result is not None
    assert not result.content
    assert result.file_path == str(settings.FILE_STORAGE_DIR / "sender@example.com" / "INBOX" / "test.txt")
    assert result.filename == str(
        settings.FILE_STORAGE_DIR
        / "sender_example_com"
        / "INBOX"
        / "test_with_special_chars.txt"
    )


def test_process_attachment_write_error():
@ -296,7 +118,6 @@ def test_process_attachment_write_error():
    }
    message = MailMessage(
        id=1,
        source=SourceItem(tags=["test"]),
        message_id="<test@example.com>",
        sender="sender@example.com",
        folder="INBOX",
@ -344,7 +165,7 @@ def test_process_attachments_mixed():
    ]
    message = MailMessage(
        id=1,
        source=SourceItem(tags=["test"]),
        tags=["test"],
        message_id="<test@example.com>",
        sender="sender@example.com",
        folder="INBOX",
@ -357,84 +178,14 @@ def test_process_attachments_mixed():
    # Verify we have all attachments processed
    assert len(results) == 3

    # Verify small attachments are base64 encoded
    assert isinstance(results[0].content, bytes)
    assert isinstance(results[2].content, bytes)
    assert results[0].content == "a" * 20
    assert results[2].content == "c" * 30

    # Verify large attachment has a path
    assert results[1].file_path is not None


@pytest.mark.parametrize(
    "msg_id, subject, sender, body, expected",
    [
        (
            "<test@example.com>",
            "Test Subject",
            "sender@example.com",
            "Test body",
            b"\xf2\xbd",  # First two bytes of the actual hash
        ),
        (
            "<different@example.com>",
            "Test Subject",
            "sender@example.com",
            "Test body",
            b"\xa4\x15",  # Will be different from the first hash
        ),
    ],
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
    result = compute_message_hash(msg_id, subject, sender, body)

    # Verify it's bytes and correct length for SHA-256 (32 bytes)
    assert isinstance(result, bytes)
    assert len(result) == 32

    # Verify first two bytes match expected
    assert result[:2] == expected


def test_hash_consistency():
    args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
    assert compute_message_hash(*args) == compute_message_hash(*args)


def test_parse_simple_email():
    test_date = datetime(2023, 1, 1, 12, 0, 0)
    msg_id = "<test123@example.com>"
    msg = create_email_message(
        subject="Test Subject",
        from_addr="sender@example.com",
        to_addrs="recipient@example.com",
        date=test_date,
        body="Test body content",
        message_id=msg_id,
    assert results[1].filename == str(
        settings.FILE_STORAGE_DIR / "sender_example_com" / "INBOX" / "large.txt"
    )

    result = parse_email_message(msg.as_string())

    assert result == {
        "message_id": msg_id,
        "subject": "Test Subject",
        "sender": "sender@example.com",
        "recipients": ["recipient@example.com"],
        "body": "Test body content\n",
        "attachments": [],
        "sent_at": ANY,
    }
    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400


def test_parse_email_with_attachments():
    attachments = [{"filename": "test.txt", "content": b"attachment content"}]
    msg = create_email_message(attachments=attachments)

    result = parse_email_message(msg.as_string())

    assert len(result["attachments"]) == 1
    assert result["attachments"][0]["filename"] == "test.txt"


def test_extract_email_uid_valid():
    msg_data = [(b"1 (UID 12345 RFC822 {1234}", b"raw email content")]
@ -452,148 +203,55 @@ def test_extract_email_uid_no_match():
    assert raw_email == b"raw email content"


def test_create_source_item(db_session):
    # Mock data
    message_hash = b"test_hash_bytes" + bytes(28)  # 32 bytes for SHA-256
    account_tags = ["work", "important"]
    raw_email_size = 1024

    # Call function
    source_item = create_source_item(
        db_session=db_session,
        message_hash=message_hash,
        account_tags=account_tags,
        raw_size=raw_email_size,
    )

    # Verify the source item was created correctly
    assert isinstance(source_item, SourceItem)
    assert source_item.id is not None
    assert source_item.modality == "mail"
    assert source_item.sha256 == message_hash
    assert source_item.tags == account_tags
    assert source_item.byte_length == raw_email_size
    assert source_item.mime_type == "message/rfc822"
    assert source_item.embed_status == "RAW"

    # Verify it was added to the session
    db_session.flush()
    fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
    assert fetched_item is not None
    assert fetched_item.sha256 == message_hash


@pytest.mark.parametrize(
    "setup_db, message_id, message_hash, expected_exists",
    [
        # Test by message ID
        (
            lambda db: (
                # First create source_item to satisfy foreign key constraint
                db.add(
                    SourceItem(
                        id=1,
                        modality="mail",
                        sha256=b"some_hash_bytes" + bytes(28),
                        tags=["test"],
                        byte_length=100,
                        mime_type="message/rfc822",
                        embed_status="RAW",
                    )
                ),
                db.flush(),
                # Then create mail_message
                db.add(
                    MailMessage(
                        source_id=1,
                        message_id="<test@example.com>",
                        subject="Test",
                        sender="test@example.com",
                        recipients=["recipient@example.com"],
                        body_raw="Test body",
                    )
                ),
            ),
            "<test@example.com>",
            b"unmatched_hash",
            True,
        ),
        # Test by non-existent message ID
        (lambda db: None, "<nonexistent@example.com>", b"unmatched_hash", False),
        # Test by hash
        (
            lambda db: db.add(
                SourceItem(
                    modality="mail",
                    sha256=b"test_hash_bytes" + bytes(28),
                    tags=["test"],
                    byte_length=100,
                    mime_type="message/rfc822",
                    embed_status="RAW",
                )
            ),
            "",
            b"test_hash_bytes" + bytes(28),
            True,
        ),
        # Test by non-existent hash
        (lambda db: None, "", b"different_hash_" + bytes(28), False),
    ],
)
def test_check_message_exists(
    db_session, setup_db, message_id, message_hash, expected_exists
):
    # Setup test data
    if setup_db:
        setup_db(db_session)
        db_session.flush()

    # Test the function
    assert check_message_exists(db_session, message_id, message_hash) == expected_exists
def test_create_mail_message(db_session):
    source_item = SourceItem(
        modality="mail",
        sha256=b"test_hash_bytes" + bytes(28),
        tags=["test"],
        byte_length=100,
    raw_email = (
        "From: sender@example.com\n"
        "To: recipient@example.com\n"
        "Subject: Test Subject\n"
        "Date: Sun, 1 Jan 2023 12:00:00 +0000\n"
        "Message-ID: 321\n"
        "MIME-Version: 1.0\n"
        'Content-Type: multipart/mixed; boundary="boundary"\n'
        "\n"
        "--boundary\n"
        "Content-Type: text/plain\n"
        "\n"
        "Test body content\n"
        "--boundary\n"
        'Content-Disposition: attachment; filename="test.txt"\n'
        "Content-Type: text/plain\n"
        "Content-Transfer-Encoding: base64\n"
        "\n"
        "YXR0YWNobWVudCBjb250ZW50\n"
        "--boundary--"
    )
    db_session.add(source_item)
    db_session.flush()
    parsed_email = {
        "message_id": "<test@example.com>",
        "subject": "Test Subject",
        "sender": "sender@example.com",
        "recipients": ["recipient@example.com"],
        "sent_at": datetime(2023, 1, 1, 12, 0, 0),
        "body": "Test body content",
        "attachments": [
            {"filename": "test.txt", "content_type": "text/plain", "size": 100}
        ],
    }
    folder = "INBOX"

    # Call function
    mail_message = create_mail_message(
        db_session=db_session,
        source_item=source_item,
        parsed_email=parsed_email,
        raw_email=raw_email,
        folder=folder,
        tags=["test"],
        message_id="123",
    )
    db_session.commit()

    attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all()
    attachments = (
        db_session.query(EmailAttachment)
        .filter(EmailAttachment.mail_message_id == mail_message.id)
        .all()
    )

    # Verify the mail message was created correctly
    assert isinstance(mail_message, MailMessage)
    assert mail_message.source_id == source_item.id
    assert mail_message.message_id == parsed_email["message_id"]
    assert mail_message.subject == parsed_email["subject"]
    assert mail_message.sender == parsed_email["sender"]
    assert mail_message.recipients == parsed_email["recipients"]
    assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat()
    assert mail_message.body_raw == parsed_email["body"]
    assert mail_message.message_id == "321"
    assert mail_message.subject == "Test Subject"
    assert mail_message.sender == "sender@example.com"
    assert mail_message.recipients == ["recipient@example.com"]
    assert mail_message.sent_at.isoformat()[:-6] == "2023-01-01T12:00:00"
    assert mail_message.content == raw_email
    assert mail_message.body == "Test body content\n"
    assert mail_message.attachments == attachments
@ -675,99 +333,97 @@ def test_process_folder_error(email_provider):

def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
    mail_message = MailMessage(
        source=SourceItem(
            modality="mail",
            sha256=b"test_hash" + bytes(24),
            tags=["test"],
            byte_length=100,
            mime_type="message/rfc822",
            embed_status="RAW",
        ),
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        size=100,
        mime_type="message/rfc822",
        embed_status="RAW",
        message_id="<test-vector@example.com>",
        subject="Test Vectorization",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email for vectorization",
        content="This is a test email for vectorization",
        folder="INBOX",
    )
    db_session.add(mail_message)
    db_session.flush()

    assert mail_message.embed_status == "RAW"

    with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
        vector_ids = vectorize_email(mail_message)

    assert len(vector_ids) == 1
    assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
        vectorize_email(mail_message)
    assert [c.id for c in mail_message.chunks] == [
        "00000000-0000-0000-0000-000000000001"
    ]

    db_session.commit()
    assert mail_message.embed_status == "STORED"


def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
    mail_message = MailMessage(
        source=SourceItem(
            modality="mail",
            sha256=b"test_hash" + bytes(24),
            tags=["test"],
            byte_length=100,
            mime_type="message/rfc822",
            embed_status="RAW",
        ),
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        size=100,
        mime_type="message/rfc822",
        embed_status="RAW",
        message_id="<test-vector-attach@example.com>",
        subject="Test Vectorization with Attachments",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email with attachments",
        content="This is a test email with attachments",
        folder="INBOX",
    )
    db_session.add(mail_message)
    db_session.flush()

    # Add two attachments - one with content and one with file_path
    attachment1 = EmailAttachment(
        mail_message_id=mail_message.id,
        filename="inline.txt",
        content_type="text/plain",
        size=100,
        content=base64.b64encode(b"This is inline content"),
        file_path=None,
        source=SourceItem(
            modality="doc",
            sha256=b"test_hash1" + bytes(24),
            tags=["test"],
            byte_length=100,
            mime_type="text/plain",
            embed_status="RAW",
        ),
        filename=None,
        modality="doc",
        sha256=b"test_hash1" + bytes(24),
        tags=["test"],
        mime_type="text/plain",
        embed_status="RAW",
    )

    file_path = mail_message.attachments_path / "stored.txt"
    file_path.parent.mkdir(parents=True, exist_ok=True)
    file_path.write_bytes(b"This is stored content")
    attachment2 = EmailAttachment(
        mail_message_id=mail_message.id,
        filename="stored.txt",
        content_type="text/plain",
        size=200,
        content=None,
        file_path=str(file_path),
        source=SourceItem(
            modality="doc",
            sha256=b"test_hash2" + bytes(24),
            tags=["test"],
            byte_length=100,
            mime_type="text/plain",
            embed_status="RAW",
        ),
        filename=str(file_path),
        modality="doc",
        sha256=b"test_hash2" + bytes(24),
        tags=["test"],
        mime_type="text/plain",
        embed_status="RAW",
    )

    db_session.add_all([attachment1, attachment2])
    db_session.flush()

    # Mock embedding functions but use real qdrant
    with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
        # Call the function
        vector_ids = vectorize_email(mail_message)

        vectorize_email(mail_message)

    # Verify results
    assert len(vector_ids) == 3
    assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
    assert vector_ids[1] == "doc/00000000-0000-0000-0000-000000000002"
    assert vector_ids[2] == "doc/00000000-0000-0000-0000-000000000003"
    vector_ids = [
        c.id for c in mail_message.chunks + attachment1.chunks + attachment2.chunks
    ]
    assert vector_ids == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
        "00000000-0000-0000-0000-000000000003",
    ]

    db_session.commit()
    assert mail_message.embed_status == "STORED"
    assert attachment1.embed_status == "STORED"
    assert attachment2.embed_status == "STORED"