alembic + tests

This commit is contained in:
Daniel O'Connell 2025-04-27 17:38:04 +02:00
parent a003ada9b7
commit d1cac9ffd9
19 changed files with 1715 additions and 242 deletions

11
.gitignore vendored
View File

@ -3,3 +3,14 @@
secrets/
acme.json
__pycache__/
*.egg-info/
*.pyc
*.pyo
*.pyd
*.pyw
*.pyz
*.pywz
*.pyzw
docker-compose.override.yml
docker/pgadmin

107
db/migrations/alembic.ini Normal file
View File

@ -0,0 +1,107 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = db/migrations
# template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
version_locations = db/migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# URL of the database to connect to
sqlalchemy.url = driver://user:pass@localhost/dbname
# enable compare_type (to detect column type changes)
# and compare_server_default (to detect changes in server default values)
compare_type = true
compare_server_default = true
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
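
With this config in place, migrations can be applied with `alembic -c db/migrations/alembic.ini upgrade head` (as the test fixtures below do) or programmatically. A minimal sketch, assuming it is run from the repository root:

# Sketch: apply all migrations using the alembic.ini above (run from the repo root).
from alembic import command
from alembic.config import Config

cfg = Config("db/migrations/alembic.ini")
command.upgrade(cfg, "head")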

82
db/migrations/env.py Normal file
View File

@ -0,0 +1,82 @@
"""
Alembic environment configuration.
"""
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from memory.common import settings
from memory.common.db.models import Base
# this is the Alembic Config object
config = context.config
# setup database URL from environment variables
config.set_main_option("sqlalchemy.url", settings.DB_URL)
# Interpret the config file for Python logging
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode - creates SQL scripts.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""
Run migrations in 'online' mode - directly to the database.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
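
run_migrations_offline() above is the path taken when SQL scripts are requested instead of a live connection; a minimal sketch of that mode, assuming the same config file:

# Sketch: emit the migration SQL without a database connection (offline mode).
from alembic import command
from alembic.config import Config

command.upgrade(Config("db/migrations/alembic.ini"), "head", sql=True)
# equivalent to: alembic -c db/migrations/alembic.ini upgrade head --sql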

View File

@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
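
New revisions rendered from this template are autogenerated against Base.metadata (wired up in env.py above). A hedged sketch of generating one, where the revision message is illustrative:

# Sketch: autogenerate a new revision rendered from script.py.mako (message is illustrative).
from alembic import command
from alembic.config import Config

command.revision(Config("db/migrations/alembic.ini"), message="add example table", autogenerate=True)
# file_template in alembic.ini names the file, e.g. db/migrations/versions/20250427_173800_add_example_table.py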

View File

@ -0,0 +1,341 @@
"""Initial structure
Revision ID: a466a07360d5
Revises:
Create Date: 2025-04-27 17:15:37.487616
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "a466a07360d5"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute('CREATE EXTENSION IF NOT EXISTS pgcrypto')
# Create enum type for github_item with IF NOT EXISTS
op.execute("DO $$ BEGIN CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); EXCEPTION WHEN duplicate_object THEN NULL; END $$;")
op.create_table(
"email_accounts",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("email_address", sa.Text(), nullable=False),
sa.Column("imap_server", sa.Text(), nullable=False),
sa.Column("imap_port", sa.Integer(), server_default="993", nullable=False),
sa.Column("username", sa.Text(), nullable=False),
sa.Column("password", sa.Text(), nullable=False),
sa.Column("use_ssl", sa.Boolean(), server_default="true", nullable=False),
sa.Column("folders", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("active", sa.Boolean(), server_default="true", nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email_address"),
)
op.create_table(
"rss_feeds",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("url", sa.Text(), nullable=False),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("last_checked_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("active", sa.Boolean(), server_default="true", nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_table(
"source_item",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("modality", sa.Text(), nullable=False),
sa.Column("sha256", postgresql.BYTEA(), nullable=False),
sa.Column(
"inserted_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("lang", sa.Text(), nullable=True),
sa.Column("model_hash", sa.Text(), nullable=True),
sa.Column(
"vector_ids", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
),
sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
sa.Column("byte_length", sa.Integer(), nullable=True),
sa.Column("mime_type", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("sha256"),
)
op.create_table(
"blog_post",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("url", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_table(
"book_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("chapter", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"chat_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("platform", sa.Text(), nullable=True),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"git_commit",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("repo_path", sa.Text(), nullable=True),
sa.Column("commit_sha", sa.Text(), nullable=True),
sa.Column("author_name", sa.Text(), nullable=True),
sa.Column("author_email", sa.Text(), nullable=True),
sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
sa.Column("msg_raw", sa.Text(), nullable=True),
sa.Column("diff_summary", sa.Text(), nullable=True),
sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("commit_sha"),
)
op.create_table(
"mail_message",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("subject", sa.Text(), nullable=True),
sa.Column("sender", sa.Text(), nullable=True),
sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.Column(
"attachments", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("message_id"),
)
op.create_table(
"misc_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("path", sa.Text(), nullable=True),
sa.Column("mime_type", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"photo",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("exif_lat", sa.Numeric(9, 6), nullable=True),
sa.Column("exif_lon", sa.Numeric(9, 6), nullable=True),
sa.Column("camera", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Add github_item table
op.create_table(
"github_item",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("kind", sa.Text(), nullable=False),
sa.Column("repo_path", sa.Text(), nullable=False),
sa.Column("number", sa.Integer(), nullable=True),
sa.Column("parent_number", sa.Integer(), nullable=True),
sa.Column("commit_sha", sa.Text(), nullable=True),
sa.Column("state", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("body_raw", sa.Text(), nullable=True),
sa.Column("labels", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("merged_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("diff_summary", sa.Text(), nullable=True),
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Add constraint to github_item.kind
op.create_check_constraint(
"github_item_kind_check",
"github_item",
"kind IN ('issue', 'pr', 'comment', 'project_card')"
)
# Add missing constraint to source_item
op.create_check_constraint(
"source_item_embed_status_check",
"source_item",
"embed_status IN ('RAW','QUEUED','STORED','FAILED')"
)
# Create trigger function for vector_ids validation
op.execute('''
CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
RETURNS TRIGGER LANGUAGE plpgsql AS $$
BEGIN
IF NEW.embed_status = 'STORED'
AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
RAISE EXCEPTION
USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
END IF;
RETURN NEW;
END;
$$;
''')
# Create trigger
op.execute('''
CREATE TRIGGER check_vector_ids
BEFORE UPDATE ON source_item
FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();
''')
# Create indexes for source_item
op.create_index('source_modality_idx', 'source_item', ['modality'])
op.create_index('source_status_idx', 'source_item', ['embed_status'])
op.create_index('source_tags_idx', 'source_item', ['tags'], postgresql_using='gin')
# Create indexes for mail_message
op.create_index('mail_sent_idx', 'mail_message', ['sent_at'])
op.create_index('mail_recipients_idx', 'mail_message', ['recipients'], postgresql_using='gin')
op.create_index('mail_tsv_idx', 'mail_message', ['tsv'], postgresql_using='gin')
# Create index for chat_message
op.create_index('chat_channel_idx', 'chat_message', ['platform', 'channel_id'])
# Create indexes for git_commit
op.create_index('git_files_idx', 'git_commit', ['files_changed'], postgresql_using='gin')
op.create_index('git_date_idx', 'git_commit', ['author_date'])
# Create index for photo
op.create_index('photo_taken_idx', 'photo', ['exif_taken_at'])
# Create indexes for rss_feeds
op.create_index('rss_feeds_active_idx', 'rss_feeds', ['active', 'last_checked_at'])
op.create_index('rss_feeds_tags_idx', 'rss_feeds', ['tags'], postgresql_using='gin')
# Create indexes for email_accounts
op.create_index('email_accounts_address_idx', 'email_accounts', ['email_address'], unique=True)
op.create_index('email_accounts_active_idx', 'email_accounts', ['active', 'last_sync_at'])
op.create_index('email_accounts_tags_idx', 'email_accounts', ['tags'], postgresql_using='gin')
# Create indexes for github_item
op.create_index('gh_repo_kind_idx', 'github_item', ['repo_path', 'kind'])
op.create_index('gh_issue_lookup_idx', 'github_item', ['repo_path', 'kind', 'number'])
op.create_index('gh_labels_idx', 'github_item', ['labels'], postgresql_using='gin')
# Create add_tags helper function
op.execute('''
CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
RETURNS VOID LANGUAGE SQL AS $$
UPDATE source_item
SET tags =
(SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
WHERE id = p_source;
$$;
''')
def downgrade() -> None:
# Drop indexes (gh_tsv_idx is never created in upgrade(), so it is not dropped here)
op.drop_index('gh_labels_idx', table_name='github_item')
op.drop_index('gh_issue_lookup_idx', table_name='github_item')
op.drop_index('gh_repo_kind_idx', table_name='github_item')
op.drop_index('email_accounts_tags_idx', table_name='email_accounts')
op.drop_index('email_accounts_active_idx', table_name='email_accounts')
op.drop_index('email_accounts_address_idx', table_name='email_accounts')
op.drop_index('rss_feeds_tags_idx', table_name='rss_feeds')
op.drop_index('rss_feeds_active_idx', table_name='rss_feeds')
op.drop_index('photo_taken_idx', table_name='photo')
op.drop_index('git_date_idx', table_name='git_commit')
op.drop_index('git_files_idx', table_name='git_commit')
op.drop_index('chat_channel_idx', table_name='chat_message')
op.drop_index('mail_tsv_idx', table_name='mail_message')
op.drop_index('mail_recipients_idx', table_name='mail_message')
op.drop_index('mail_sent_idx', table_name='mail_message')
op.drop_index('source_tags_idx', table_name='source_item')
op.drop_index('source_status_idx', table_name='source_item')
op.drop_index('source_modality_idx', table_name='source_item')
# Drop tables
op.drop_table("photo")
op.drop_table("misc_doc")
op.drop_table("mail_message")
op.drop_table("git_commit")
op.drop_table("chat_message")
op.drop_table("book_doc")
op.drop_table("blog_post")
op.drop_table("github_item")
op.drop_table("source_item")
op.drop_table("rss_feeds")
op.drop_table("email_accounts")
# Drop triggers and functions
op.execute("DROP TRIGGER IF EXISTS check_vector_ids ON source_item")
op.execute("DROP FUNCTION IF EXISTS trg_vector_ids_not_empty()")
op.execute("DROP FUNCTION IF EXISTS add_tags(BIGINT, TEXT[])")
# Drop enum type
op.execute("DROP TYPE IF EXISTS gh_item_kind")

View File

@ -1,229 +0,0 @@
/*========================================================================
Knowledge-Base schema first-run script
---------------------------------------------------------------
PostgreSQL 15+
Creates every table, index, trigger and helper in one pass
No ALTER statements or later migrations required
Enable pgcrypto for UUID helpers (safe to re-run)
========================================================================*/
-------------------------------------------------------------------------------
-- 0. EXTENSIONS
-------------------------------------------------------------------------------
CREATE EXTENSION IF NOT EXISTS pgcrypto; -- gen_random_uuid(), crypt()
-------------------------------------------------------------------------------
-- 1. CANONICAL ARTEFACT TABLE (everything points here)
-------------------------------------------------------------------------------
CREATE TABLE source_item (
id BIGSERIAL PRIMARY KEY,
modality TEXT NOT NULL, -- 'mail'|'chat'|...
sha256 BYTEA UNIQUE NOT NULL, -- 32-byte blob
inserted_at TIMESTAMPTZ DEFAULT NOW(),
tags TEXT[] NOT NULL DEFAULT '{}', -- flexible labels
lang TEXT, -- ISO-639-1 or NULL
model_hash TEXT, -- embedding model ver.
vector_ids TEXT[] NOT NULL DEFAULT '{}', -- 0-N Qdrant IDs
embed_status TEXT NOT NULL DEFAULT 'RAW'
CHECK (embed_status IN ('RAW','QUEUED','STORED','FAILED')),
byte_length INTEGER, -- original size
mime_type TEXT
);
CREATE INDEX source_modality_idx ON source_item (modality);
CREATE INDEX source_status_idx ON source_item (embed_status);
CREATE INDEX source_tags_idx ON source_item USING GIN (tags);
-- 1.a Trigger vector_ids must be present when status = STORED
CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
RETURNS TRIGGER LANGUAGE plpgsql AS $$
BEGIN
IF NEW.embed_status = 'STORED'
AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
RAISE EXCEPTION
USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
END IF;
RETURN NEW;
END;
$$;
CREATE TRIGGER check_vector_ids
BEFORE UPDATE ON source_item
FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();
-------------------------------------------------------------------------------
-- 2. MAIL MESSAGES
-------------------------------------------------------------------------------
CREATE TABLE mail_message (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
message_id TEXT UNIQUE,
subject TEXT,
sender TEXT,
recipients TEXT[],
sent_at TIMESTAMPTZ,
body_raw TEXT,
attachments JSONB
);
CREATE INDEX mail_sent_idx ON mail_message (sent_at);
CREATE INDEX mail_recipients_idx ON mail_message USING GIN (recipients);
ALTER TABLE mail_message
ADD COLUMN tsv tsvector
GENERATED ALWAYS AS (
to_tsvector('english',
coalesce(subject,'') || ' ' || coalesce(body_raw,'')))
STORED;
CREATE INDEX mail_tsv_idx ON mail_message USING GIN (tsv);
-------------------------------------------------------------------------------
-- 3. CHAT (Slack / Discord)
-------------------------------------------------------------------------------
CREATE TABLE chat_message (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
platform TEXT CHECK (platform IN ('slack','discord')),
channel_id TEXT,
author TEXT,
sent_at TIMESTAMPTZ,
body_raw TEXT
);
CREATE INDEX chat_channel_idx ON chat_message (platform, channel_id);
-------------------------------------------------------------------------------
-- 4. GIT COMMITS (local repos)
-------------------------------------------------------------------------------
CREATE TABLE git_commit (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
repo_path TEXT,
commit_sha TEXT UNIQUE,
author_name TEXT,
author_email TEXT,
author_date TIMESTAMPTZ,
msg_raw TEXT,
diff_summary TEXT,
files_changed TEXT[]
);
CREATE INDEX git_files_idx ON git_commit USING GIN (files_changed);
CREATE INDEX git_date_idx ON git_commit (author_date);
-------------------------------------------------------------------------------
-- 5. PHOTOS
-------------------------------------------------------------------------------
CREATE TABLE photo (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
file_path TEXT,
exif_taken_at TIMESTAMPTZ,
exif_lat NUMERIC(9,6),
exif_lon NUMERIC(9,6),
camera_make TEXT,
camera_model TEXT
);
CREATE INDEX photo_taken_idx ON photo (exif_taken_at);
-------------------------------------------------------------------------------
-- 6. BOOKS, BLOG POSTS, MISC DOCS
-------------------------------------------------------------------------------
CREATE TABLE book_doc (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
title TEXT,
author TEXT,
chapter TEXT,
published DATE
);
CREATE TABLE blog_post (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
url TEXT UNIQUE,
title TEXT,
published TIMESTAMPTZ
);
CREATE TABLE misc_doc (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
path TEXT,
mime_type TEXT
);
-------------------------------------------------------------------------------
-- 6.5 RSS FEEDS
-------------------------------------------------------------------------------
CREATE TABLE rss_feeds (
id BIGSERIAL PRIMARY KEY,
url TEXT UNIQUE NOT NULL,
title TEXT,
description TEXT,
tags TEXT[] NOT NULL DEFAULT '{}',
last_checked_at TIMESTAMPTZ,
active BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX rss_feeds_active_idx ON rss_feeds (active, last_checked_at);
CREATE INDEX rss_feeds_tags_idx ON rss_feeds USING GIN (tags);
-------------------------------------------------------------------------------
-- 7. GITHUB ITEMS (issues, PRs, comments, project cards)
-------------------------------------------------------------------------------
CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card');
CREATE TABLE github_item (
id BIGSERIAL PRIMARY KEY,
source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
kind gh_item_kind NOT NULL,
repo_path TEXT NOT NULL, -- "owner/repo"
number INTEGER, -- issue/PR number (NULL for commit comment)
parent_number INTEGER, -- comment → its issue/PR
commit_sha TEXT, -- for commit comments
state TEXT, -- 'open'|'closed'|'merged'
title TEXT,
body_raw TEXT,
labels TEXT[],
author TEXT,
created_at TIMESTAMPTZ,
closed_at TIMESTAMPTZ,
merged_at TIMESTAMPTZ,
diff_summary TEXT, -- PR only
payload JSONB -- extra GitHub fields
);
CREATE INDEX gh_repo_kind_idx ON github_item (repo_path, kind);
CREATE INDEX gh_issue_lookup_idx ON github_item (repo_path, kind, number);
CREATE INDEX gh_labels_idx ON github_item USING GIN (labels);
CREATE INDEX gh_tsv_idx ON github_item
WHERE kind IN ('issue','pr')
USING GIN (to_tsvector('english',
coalesce(title,'') || ' ' || coalesce(body_raw,'')));
-------------------------------------------------------------------------------
-- 8. HELPER FUNCTION add tags
-------------------------------------------------------------------------------
CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
RETURNS VOID LANGUAGE SQL AS $$
UPDATE source_item
SET tags =
(SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
WHERE id = p_source;
$$;
-------------------------------------------------------------------------------
-- 9. (optional) PARTITION STUBS create per-year partitions later
-------------------------------------------------------------------------------
/*
-- example:
CREATE TABLE mail_message_2026 PARTITION OF mail_message
FOR VALUES FROM ('2026-01-01') TO ('2027-01-01');
*/
-- =========================================================================
-- Schema creation complete
-- =========================================================================

59
dev.sh Executable file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env bash
set -eo pipefail
# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
echo -e "${GREEN}Starting development environment for Memory Knowledge Base...${NC}"
# Get the directory of the script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# Create a temporary docker-compose override file to expose PostgreSQL
echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}"
if [ ! -f docker-compose.override.yml ]; then
cat > docker-compose.override.yml << EOL
version: "3.9"
services:
postgres:
ports:
- "5432:5432"
EOL
fi
# Start the containers
echo -e "${GREEN}Starting docker containers...${NC}"
docker-compose up -d postgres rabbitmq qdrant
# Wait for PostgreSQL to be ready
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
for i in {1..30}; do
if docker-compose exec postgres pg_isready -U kb > /dev/null 2>&1; then
echo -e "${GREEN}PostgreSQL is ready!${NC}"
break
fi
echo -n "."
sleep 1
done
# Initialize the database if needed
echo -e "${YELLOW}Checking if database needs initialization...${NC}"
if ! docker-compose exec postgres psql -U kb -d kb -c "SELECT 1 FROM information_schema.tables WHERE table_name = 'source_item'" | grep -q 1; then
echo -e "${GREEN}Initializing database from schema.sql...${NC}"
docker-compose exec postgres psql -U kb -d kb -f /docker-entrypoint-initdb.d/schema.sql
else
echo -e "${GREEN}Database already initialized.${NC}"
fi
echo -e "${GREEN}Development environment is ready!${NC}"
echo -e "${YELLOW}PostgreSQL is available at localhost:5432${NC}"
echo -e "${YELLOW}Username: kb${NC}"
echo -e "${YELLOW}Password: (check secrets/postgres_password.txt)${NC}"
echo -e "${YELLOW}Database: kb${NC}"
echo ""
echo -e "${GREEN}To stop the environment, run:${NC}"
echo -e "${YELLOW}docker-compose down${NC}"

View File

@ -61,7 +61,6 @@ services:
secrets: [postgres_password]
volumes:
- db_data:/var/lib/postgresql/data:rw
-   - ./db:/docker-entrypoint-initdb.d:ro
healthcheck:
test: ["CMD-SHELL", "pg_isready -U kb"]
interval: 10s

View File

@ -1,3 +1,5 @@
sqlalchemy==2.0.30
psycopg2-binary==2.9.9
pydantic==2.7.1
alembic==1.13.1
dotenv==1.1.0

5
requirements-dev.txt Normal file
View File

@ -0,0 +1,5 @@
pytest==7.4.4
pytest-cov==4.1.0
black==23.12.1
mypy==1.8.0
isort==5.13.2

View File

@ -16,6 +16,7 @@ def read_requirements(filename: str) -> list[str]:
common_requires = read_requirements('requirements-common.txt')
api_requires = read_requirements('requirements-api.txt')
workers_requires = read_requirements('requirements-workers.txt')
dev_requires = read_requirements('requirements-dev.txt')
setup(
name="memory",
@ -27,5 +28,7 @@ setup(
"api": api_requires + common_requires, "api": api_requires + common_requires,
"workers": workers_requires + common_requires, "workers": workers_requires + common_requires,
"common": common_requires, "common": common_requires,
"dev": dev_requires,
"all": api_requires + workers_requires + common_requires + dev_requires,
},
)

View File

@ -30,3 +30,8 @@ def get_scoped_session():
engine = get_engine()
session_factory = sessionmaker(bind=engine)
return scoped_session(session_factory)
# note: requires `from contextlib import contextmanager` at the top of this module,
# so that `with make_session() as session:` (used by the tasks below) works
@contextmanager
def make_session():
with get_scoped_session() as session:
yield session
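
Assuming the @contextmanager fix above, usage mirrors the email tasks later in this commit:

# Sketch: consuming make_session() as a context manager.
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount

with make_session() as session:
    active_accounts = session.query(EmailAccount).filter(EmailAccount.active).all()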

View File

@ -2,8 +2,8 @@
Database models for the knowledge base system.
"""
from sqlalchemy import (
- Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean, Float,
- ARRAY, func
+ Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean,
+ ARRAY, func, Numeric, CheckConstraint, Index
)
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
from sqlalchemy.ext.declarative import declarative_base
@ -27,6 +27,14 @@ class SourceItem(Base):
byte_length = Column(Integer)
mime_type = Column(Text)
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
Index('source_modality_idx', 'modality'),
Index('source_status_idx', 'embed_status'),
Index('source_tags_idx', 'tags', postgresql_using='gin'),
)
class MailMessage(Base):
__tablename__ = 'mail_message'
@ -42,6 +50,13 @@ class MailMessage(Base):
attachments = Column(JSONB)
tsv = Column(TSVECTOR)
# Add indexes
__table_args__ = (
Index('mail_sent_idx', 'sent_at'),
Index('mail_recipients_idx', 'recipients', postgresql_using='gin'),
Index('mail_tsv_idx', 'tsv', postgresql_using='gin'),
)
class ChatMessage(Base):
__tablename__ = 'chat_message'
@ -54,6 +69,11 @@ class ChatMessage(Base):
sent_at = Column(DateTime(timezone=True))
body_raw = Column(Text)
# Add index
__table_args__ = (
Index('chat_channel_idx', 'platform', 'channel_id'),
)
class GitCommit(Base):
__tablename__ = 'git_commit'
@ -69,6 +89,12 @@ class GitCommit(Base):
diff_summary = Column(Text)
files_changed = Column(ARRAY(Text))
# Add indexes
__table_args__ = (
Index('git_files_idx', 'files_changed', postgresql_using='gin'),
Index('git_date_idx', 'author_date'),
)
class Photo(Base):
__tablename__ = 'photo'
@ -77,10 +103,14 @@ class Photo(Base):
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
file_path = Column(Text)
exif_taken_at = Column(DateTime(timezone=True))
- exif_lat = Column(Float)
- exif_lon = Column(Float)
- camera_make = Column(Text)
- camera_model = Column(Text)
+ exif_lat = Column(Numeric(9, 6))
+ exif_lon = Column(Numeric(9, 6))
+ camera = Column(Text)
# Add index
__table_args__ = (
Index('photo_taken_idx', 'exif_taken_at'),
)
class BookDoc(Base):
@ -91,7 +121,7 @@ class BookDoc(Base):
title = Column(Text)
author = Column(Text)
chapter = Column(Text)
- published = Column(DateTime)
+ published = Column(DateTime(timezone=True))
class BlogPost(Base):
@ -125,3 +155,66 @@ class RssFeed(Base):
active = Column(Boolean, nullable=False, server_default='true')
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
# Add indexes
__table_args__ = (
Index('rss_feeds_active_idx', 'active', 'last_checked_at'),
Index('rss_feeds_tags_idx', 'tags', postgresql_using='gin'),
)
class EmailAccount(Base):
__tablename__ = 'email_accounts'
id = Column(BigInteger, primary_key=True)
name = Column(Text, nullable=False)
email_address = Column(Text, nullable=False, unique=True)
imap_server = Column(Text, nullable=False)
imap_port = Column(Integer, nullable=False, server_default='993')
username = Column(Text, nullable=False)
password = Column(Text, nullable=False)
use_ssl = Column(Boolean, nullable=False, server_default='true')
folders = Column(ARRAY(Text), nullable=False, server_default='{}')
tags = Column(ARRAY(Text), nullable=False, server_default='{}')
last_sync_at = Column(DateTime(timezone=True))
active = Column(Boolean, nullable=False, server_default='true')
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
# Add indexes
__table_args__ = (
Index('email_accounts_address_idx', 'email_address', unique=True),
Index('email_accounts_active_idx', 'active', 'last_sync_at'),
Index('email_accounts_tags_idx', 'tags', postgresql_using='gin'),
)
class GithubItem(Base):
__tablename__ = 'github_item'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
kind = Column(Text, nullable=False)
repo_path = Column(Text, nullable=False)
number = Column(Integer)
parent_number = Column(Integer)
commit_sha = Column(Text)
state = Column(Text)
title = Column(Text)
body_raw = Column(Text)
labels = Column(ARRAY(Text))
author = Column(Text)
created_at = Column(DateTime(timezone=True))
closed_at = Column(DateTime(timezone=True))
merged_at = Column(DateTime(timezone=True))
diff_summary = Column(Text)
payload = Column(JSONB)
__table_args__ = (
CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
Index('gh_repo_kind_idx', 'repo_path', 'kind'),
Index('gh_issue_lookup_idx', 'repo_path', 'kind', 'number'),
Index('gh_labels_idx', 'labels', postgresql_using='gin'),
)

View File

@ -0,0 +1,16 @@
import os
from dotenv import load_dotenv
load_dotenv()
DB_USER = os.getenv("DB_USER", "kb")
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
DB_HOST = os.getenv("DB_HOST", "postgres")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "kb")
def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME):
return f"postgresql://{user}:{password}@{host}:{port}/{db}"
DB_URL = os.getenv("DATABASE_URL", make_db_url())
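
DATABASE_URL takes precedence when set; otherwise the URL is assembled from the DB_* variables. The test fixtures below lean on make_db_url to point at throwaway databases, roughly:

# Sketch: building a URL for a separate database (as conftest.py does for test databases).
from memory.common import settings

test_url = settings.make_db_url(db="test_db_1234")  # db name is illustrative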

362
src/memory/workers/email.py Normal file
View File

@ -0,0 +1,362 @@
import email
import hashlib
import imaplib
import logging
import re
from contextlib import contextmanager
from datetime import datetime
from email.utils import parsedate_to_datetime
from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem
logger = logging.getLogger(__name__)
def extract_recipients(msg: email.message.Message) -> list[str]:
"""
Extract email recipients from message headers.
Args:
msg: Email message object
Returns:
List of recipient email addresses
"""
return [
recipient
for field in ["To", "Cc", "Bcc"]
if (field_value := msg.get(field, ""))
for r in field_value.split(",")
if (recipient := r.strip())
]
def extract_date(msg: email.message.Message) -> datetime | None:
"""
Parse date from email header.
Args:
msg: Email message object
Returns:
Parsed datetime or None if parsing failed
"""
if date_str := msg.get("Date"):
try:
return parsedate_to_datetime(date_str)
except Exception:
logger.warning(f"Could not parse date: {date_str}")
return None
def extract_body(msg: email.message.Message) -> str:
"""
Extract plain text body from email message.
Args:
msg: Email message object
Returns:
Plain text body content
"""
body = ""
if not msg.is_multipart():
try:
return msg.get_payload(decode=True).decode(errors='replace')
except Exception as e:
logger.error(f"Error decoding message body: {str(e)}")
return ""
for part in msg.walk():
content_type = part.get_content_type()
content_disposition = str(part.get("Content-Disposition", ""))
if content_type == "text/plain" and "attachment" not in content_disposition:
try:
body += part.get_payload(decode=True).decode(errors='replace') + "\n"
except Exception as e:
logger.error(f"Error decoding message part: {str(e)}")
return body
def extract_attachments(msg: email.message.Message) -> list[dict]:
"""
Extract attachment metadata from email.
Args:
msg: Email message object
Returns:
List of attachment metadata dicts
"""
if not msg.is_multipart():
return []
attachments = []
for part in msg.walk():
content_disposition = part.get("Content-Disposition", "")
if "attachment" not in content_disposition:
continue
if filename := part.get_filename():
attachments.append({
"filename": filename,
"content_type": part.get_content_type(),
"size": len(part.get_payload(decode=True))
})
return attachments
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
"""
Compute a SHA-256 hash of message content.
Args:
msg_id: Message ID
subject: Email subject
sender: Sender email
body: Message body
Returns:
SHA-256 hash as bytes
"""
hash_content = (msg_id + subject + sender + body).encode()
return hashlib.sha256(hash_content).digest()
def parse_email_message(raw_email: str) -> dict:
"""
Parse raw email into structured data.
Args:
raw_email: Raw email content as string
Returns:
Dict with parsed email data
"""
msg = email.message_from_string(raw_email)
return {
"message_id": msg.get("Message-ID", ""),
"subject": msg.get("Subject", ""),
"sender": msg.get("From", ""),
"recipients": extract_recipients(msg),
"sent_at": extract_date(msg),
"body": extract_body(msg),
"attachments": extract_attachments(msg)
}
def create_source_item(
db_session: Session,
message_hash: bytes,
account_tags: list[str],
raw_email_size: int,
) -> SourceItem:
"""
Create a new source item record.
Args:
db_session: Database session
message_hash: SHA-256 hash of message
account_tags: Tags from the email account
raw_email_size: Size of raw email in bytes
Returns:
Newly created SourceItem
"""
source_item = SourceItem(
modality="mail",
sha256=message_hash,
tags=account_tags,
byte_length=raw_email_size,
mime_type="message/rfc822",
embed_status="RAW"
)
db_session.add(source_item)
db_session.flush()
return source_item
def create_mail_message(
db_session: Session,
source_id: int,
parsed_email: dict,
folder: str,
) -> MailMessage:
"""
Create a new mail message record.
Args:
db_session: Database session
source_id: ID of the SourceItem
parsed_email: Parsed email data
folder: IMAP folder name
Returns:
Newly created MailMessage
"""
mail_message = MailMessage(
source_id=source_id,
message_id=parsed_email["message_id"],
subject=parsed_email["subject"],
sender=parsed_email["sender"],
recipients=parsed_email["recipients"],
sent_at=parsed_email["sent_at"],
body_raw=parsed_email["body"],
attachments={"items": parsed_email["attachments"], "folder": folder}
)
db_session.add(mail_message)
return mail_message
def check_message_exists(db_session: Session, message_id: str, message_hash: bytes) -> bool:
"""
Check if a message already exists in the database.
Args:
db_session: Database session
message_id: Email message ID
message_hash: SHA-256 hash of message
Returns:
True if message exists, False otherwise
"""
return bool(
# Check by message_id first (faster)
message_id and db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first()
# Then check by message_hash
or db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
)
def extract_email_uid(msg_data: list) -> tuple[str | None, bytes]:
"""
Extract the UID and raw email data from the message data.
"""
uid_pattern = re.compile(r'UID (\d+)')
uid_match = uid_pattern.search(msg_data[0][0].decode('utf-8', errors='replace'))
uid = uid_match.group(1) if uid_match else None
raw_email = msg_data[0][1]
return uid, raw_email
def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> tuple[str, bytes] | None:
try:
status, msg_data = conn.fetch(uid, '(UID RFC822)')
if status != 'OK' or not msg_data or not msg_data[0]:
logger.error(f"Error fetching message {uid}")
return None
return extract_email_uid(msg_data)
except Exception as e:
logger.error(f"Error processing message {uid}: {str(e)}")
return None
def fetch_email_since(
conn: imaplib.IMAP4_SSL,
folder: str,
since_date: datetime
) -> list[tuple[str, bytes]]:
"""
Fetch emails from a folder since a given date.
Args:
conn: IMAP connection
folder: Folder name to select
since_date: Fetch emails since this date
Returns:
List of tuples with (uid, raw_email)
"""
try:
status, counts = conn.select(folder)
if status != 'OK':
logger.error(f"Error selecting folder {folder}: {counts}")
return []
date_str = since_date.strftime("%d-%b-%Y")
status, data = conn.search(None, f'(SINCE "{date_str}")')
if status != 'OK':
logger.error(f"Error searching folder {folder}: {data}")
return []
except Exception as e:
logger.error(f"Error in fetch_email_since for folder {folder}: {str(e)}")
return []
if not data or not data[0]:
return []
return [msg for uid in data[0].split() if (msg := fetch_email(conn, uid))]
def process_folder(
conn: imaplib.IMAP4_SSL,
folder: str,
account: EmailAccount,
since_date: datetime
) -> dict:
"""
Process a single folder from an email account.
Args:
conn: Active IMAP connection
folder: Folder name to process
account: Email account configuration
since_date: Only fetch messages newer than this date
Returns:
Stats dictionary for the folder
"""
# imported lazily to avoid a circular import with memory.workers.tasks.email
from memory.workers.tasks.email import process_message
new_messages, errors = 0, 0
emails = []
try:
emails = fetch_email_since(conn, folder, since_date)
for uid, raw_email in emails:
try:
task = process_message.delay(
account_id=account.id,
message_id=uid,
folder=folder,
raw_email=raw_email.decode('utf-8', errors='replace')
)
if task:
new_messages += 1
except Exception as e:
logger.error(f"Error queuing message {uid}: {str(e)}")
errors += 1
except Exception as e:
logger.error(f"Error processing folder {folder}: {str(e)}")
errors += 1
return {
"messages_found": len(emails),
"new_messages": new_messages,
"errors": errors
}
@contextmanager
def imap_connection(account: EmailAccount) -> imaplib.IMAP4_SSL:
conn = imaplib.IMAP4_SSL(
host=account.imap_server,
port=account.imap_port
)
try:
conn.login(account.username, account.password)
yield conn
finally:
# Always try to logout and close the connection
try:
conn.logout()
except Exception as e:
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")

View File

@ -1,4 +1,4 @@
""" """
Import sub-modules so Celery can register their @app.task decorators. Import sub-modules so Celery can register their @app.task decorators.
""" """
from memory.workers.tasks import text, photo, ocr, git, rss, docs # noqa from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa

View File

@ -0,0 +1,137 @@
import logging
from datetime import datetime
from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount
from memory.workers.celery_app import app
from memory.workers.email import (
check_message_exists,
compute_message_hash,
create_mail_message,
create_source_item,
imap_connection,
parse_email_message,
process_folder,
)
logger = logging.getLogger(__name__)
@app.task(name="memory.email.process_message")
def process_message(
account_id: int, message_id: str, folder: str, raw_email: str,
) -> int | None:
"""
Process a single email message and store it in the database.
Args:
account_id: ID of the EmailAccount
message_id: UID of the message on the server
folder: Folder name where the message is stored
raw_email: Raw email content as string
Returns:
source_id if successful, None otherwise
"""
with make_session() as db:
account = db.query(EmailAccount).get(account_id)
if not account:
logger.error(f"Account {account_id} not found")
return None
parsed_email = parse_email_message(raw_email)
# Use server-provided message ID if missing
if not parsed_email["message_id"]:
parsed_email["message_id"] = f"generated-{message_id}"
message_hash = compute_message_hash(
parsed_email["message_id"],
parsed_email["subject"],
parsed_email["sender"],
parsed_email["body"]
)
if check_message_exists(db, parsed_email["message_id"], message_hash):
logger.debug(f"Message {parsed_email['message_id']} already exists in database")
return None
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
create_mail_message(db, source_item.id, parsed_email, folder)
db.commit()
# TODO: Queue for embedding once that's implemented
return source_item.id
@app.task(name="memory.email.sync_account")
def sync_account(account_id: int) -> dict:
"""
Synchronize emails from a specific account.
Args:
account_id: ID of the EmailAccount to sync
Returns:
dict with stats about the sync operation
"""
with make_session() as db:
account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
if not account or not account.active:
logger.warning(f"Account {account_id} not found or inactive")
return {"error": "Account not found or inactive"}
folders_to_process = account.folders or ["INBOX"]
since_date = account.last_sync_at or datetime(1970, 1, 1)
messages_found = 0
new_messages = 0
errors = 0
try:
with imap_connection(account) as conn:
for folder in folders_to_process:
folder_stats = process_folder(conn, folder, account, since_date)
messages_found += folder_stats["messages_found"]
new_messages += folder_stats["new_messages"]
errors += folder_stats["errors"]
account.last_sync_at = datetime.now()
db.commit()
except Exception as e:
logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
return {"error": str(e)}
return {
"account": account.email_address,
"folders_processed": len(folders_to_process),
"messages_found": messages_found,
"new_messages": new_messages,
"errors": errors
}
@app.task(name="memory.email.sync_all_accounts")
def sync_all_accounts() -> list[dict]:
"""
Synchronize all active email accounts.
Returns:
List of task IDs that were scheduled
"""
with make_session() as db:
active_accounts = db.query(EmailAccount).filter(EmailAccount.active).all()
return [
{
"account_id": account.id,
"email": account.email_address,
"task_id": sync_account.delay(account.id).id
}
for account in active_accounts
]
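
A short sketch of triggering these tasks from application code; scheduling (e.g. celery beat) is not part of this commit:

# Sketch: enqueueing the email sync tasks defined above.
from memory.workers.tasks.email import sync_account, sync_all_accounts

sync_all_accounts.delay()  # fans out one sync_account task per active account
sync_account.delay(1)      # or sync a single EmailAccount by id (id is illustrative)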

View File

@ -0,0 +1,129 @@
import os
import subprocess
import uuid
from pathlib import Path
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from memory.common import settings
def get_test_db_name() -> str:
return f"test_db_{uuid.uuid4().hex[:8]}"
def create_test_database(test_db_name: str) -> str:
"""
Create a test database with a unique name.
Args:
test_db_name: Name for the test database
Returns:
URL to the test database
"""
admin_engine = create_engine(settings.DB_URL)
# Create a new database
with admin_engine.connect() as conn:
conn.execute(text("COMMIT")) # Close any open transaction
conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}"))
conn.execute(text(f"CREATE DATABASE {test_db_name}"))
admin_engine.dispose()
return settings.make_db_url(db=test_db_name)
def drop_test_database(test_db_name: str) -> None:
"""
Drop the test database.
Args:
test_db_name: Name of the test database to drop
"""
admin_engine = create_engine(settings.DB_URL)
with admin_engine.connect() as conn:
conn.execute(text("COMMIT")) # Close any open transaction
conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}"))
def run_alembic_migrations(db_name: str) -> None:
"""Run all Alembic migrations on the test database."""
project_root = Path(__file__).parent.parent.parent.parent.parent
alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
subprocess.run(
["alembic", "-c", str(alembic_ini), "upgrade", "head"],
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
check=True,
capture_output=True,
)
@pytest.fixture
def test_db():
"""
Create a test database, run migrations, and clean up afterwards.
Returns:
The URL to the test database
"""
test_db_name = get_test_db_name()
# Create test database
test_db_url = create_test_database(test_db_name)
try:
run_alembic_migrations(test_db_name)
# Return the URL to the test database
yield test_db_url
finally:
# Clean up - drop the test database
drop_test_database(test_db_name)
@pytest.fixture
def db_engine(test_db):
"""
Create a SQLAlchemy engine connected to the test database.
Args:
test_db: URL to the test database (from the test_db fixture)
Returns:
SQLAlchemy engine
"""
engine = create_engine(test_db)
yield engine
engine.dispose()
@pytest.fixture
def db_session(db_engine):
"""
Create a new database session for a test.
Args:
db_engine: SQLAlchemy engine (from the db_engine fixture)
Returns:
SQLAlchemy session
"""
# Create a new sessionmaker
SessionLocal = sessionmaker(bind=db_engine, autocommit=False, autoflush=False)
# Create a new session
session = SessionLocal()
try:
yield session
finally:
# Close and rollback the session after the test is done
session.rollback()
session.close()

View File

@ -0,0 +1,325 @@
import email
import email.mime.multipart
import email.mime.text
import email.mime.base
from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY
import pytest
from memory.common.db.models import SourceItem
from memory.workers.email import (
compute_message_hash,
create_source_item,
extract_attachments,
extract_body,
extract_date,
extract_email_uid,
extract_recipients,
parse_email_message,
)
# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0
def _generate_test_message_id():
"""Generate a simple message ID for testing without expensive calls"""
global _msg_id_counter
_msg_id_counter += 1
return f"<test-message-{_msg_id_counter}@example.com>"
def create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
cc_addrs=None,
bcc_addrs=None,
date=None,
body="Test body content",
attachments=None,
multipart=True,
message_id=None,
):
"""Helper function to create email.message.Message objects for testing"""
if multipart:
msg = email.mime.multipart.MIMEMultipart()
msg.attach(email.mime.text.MIMEText(body))
if attachments:
for attachment in attachments:
attachment_part = email.mime.base.MIMEBase("application", "octet-stream")
attachment_part.set_payload(attachment["content"])
attachment_part.add_header(
"Content-Disposition",
f"attachment; filename={attachment['filename']}"
)
msg.attach(attachment_part)
else:
msg = email.mime.text.MIMEText(body)
msg["Subject"] = subject
msg["From"] = from_addr
msg["To"] = to_addrs
if cc_addrs:
msg["Cc"] = cc_addrs
if bcc_addrs:
msg["Bcc"] = bcc_addrs
if date:
msg["Date"] = formatdate(float(date.timestamp()))
if message_id:
msg["Message-ID"] = message_id
else:
msg["Message-ID"] = _generate_test_message_id()
return msg
@pytest.mark.parametrize(
"to_addr, cc_addr, bcc_addr, expected",
[
# Single recipient in To field
(
"recipient@example.com",
None,
None,
["recipient@example.com"]
),
# Multiple recipients in To field
(
"recipient1@example.com, recipient2@example.com",
None,
None,
["recipient1@example.com", "recipient2@example.com"]
),
# To, Cc fields
(
"recipient@example.com",
"cc@example.com",
None,
["recipient@example.com", "cc@example.com"]
),
# To, Cc, Bcc fields
(
"recipient@example.com",
"cc@example.com",
"bcc@example.com",
["recipient@example.com", "cc@example.com", "bcc@example.com"]
),
# Empty fields
(
"",
"",
"",
[]
),
]
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
assert sorted(extract_recipients(msg)) == sorted(expected)
def test_extract_date_missing():
msg = create_email_message(date=None)
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Invalid Date Format",
"2023-01-01", # ISO format but not RFC compliant
"Monday, Jan 1, 2023", # Descriptive but not RFC compliant
"01/01/2023", # Common format but not RFC compliant
"", # Empty string
]
)
def test_extract_date_invalid_formats(date_str):
msg = create_email_message()
msg["Date"] = date_str
assert extract_date(msg) is None
@pytest.mark.parametrize(
"date_str",
[
"Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format
"01 Jan 2023 12:00:00 +0000", # RFC 822 format
"Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name
]
)
def test_extract_date(date_str):
msg = create_email_message()
msg["Date"] = date_str
result = extract_date(msg)
assert result is not None
assert result.year == 2023
assert result.month == 1
assert result.day == 1
@pytest.mark.parametrize('multipart', [True, False])
def test_extract_body_text_plain(multipart):
body_content = "This is a test email body"
msg = create_email_message(body=body_content, multipart=multipart)
extracted = extract_body(msg)
# Strip newlines for comparison since multipart emails often add them
assert extracted.strip() == body_content.strip()
def test_extract_body_with_attachments():
body_content = "This is a test email body"
attachments = [
{"filename": "test.txt", "content": b"attachment content"}
]
msg = create_email_message(body=body_content, attachments=attachments)
assert body_content in extract_body(msg)
def test_extract_attachments_none():
msg = create_email_message(multipart=True)
assert extract_attachments(msg) == []
def test_extract_attachments_with_files():
attachments = [
{"filename": "test1.txt", "content": b"content1"},
{"filename": "test2.pdf", "content": b"content2"}
]
msg = create_email_message(attachments=attachments)
result = extract_attachments(msg)
assert len(result) == 2
assert result[0]["filename"] == "test1.txt"
assert result[1]["filename"] == "test2.pdf"
def test_extract_attachments_non_multipart():
msg = create_email_message(multipart=False)
assert extract_attachments(msg) == []
@pytest.mark.parametrize(
"msg_id, subject, sender, body, expected",
[
(
"<test@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xf2\xbd" # First two bytes of the actual hash
),
(
"<different@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xa4\x15" # Will be different from the first hash
),
]
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
result = compute_message_hash(msg_id, subject, sender, body)
# Verify it's bytes and correct length for SHA-256 (32 bytes)
assert isinstance(result, bytes)
assert len(result) == 32
# Verify first two bytes match expected
assert result[:2] == expected
def test_hash_consistency():
args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
assert compute_message_hash(*args) == compute_message_hash(*args)
def test_parse_simple_email():
test_date = datetime(2023, 1, 1, 12, 0, 0)
msg_id = "<test123@example.com>"
msg = create_email_message(
subject="Test Subject",
from_addr="sender@example.com",
to_addrs="recipient@example.com",
date=test_date,
body="Test body content",
message_id=msg_id
)
result = parse_email_message(msg.as_string())
assert result == {
"message_id": msg_id,
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient@example.com"],
"body": "Test body content\n",
"attachments": [],
"sent_at": ANY,
}
assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
def test_parse_email_with_attachments():
attachments = [
{"filename": "test.txt", "content": b"attachment content"}
]
msg = create_email_message(attachments=attachments)
result = parse_email_message(msg.as_string())
assert len(result["attachments"]) == 1
assert result["attachments"][0]["filename"] == "test.txt"
def test_extract_email_uid_valid():
msg_data = [(b'1 (UID 12345 RFC822 {1234}', b'raw email content')]
uid, raw_email = extract_email_uid(msg_data)
assert uid == "12345"
assert raw_email == b'raw email content'
def test_extract_email_uid_no_match():
msg_data = [(b'1 (RFC822 {1234}', b'raw email content')]
uid, raw_email = extract_email_uid(msg_data)
assert uid is None
assert raw_email == b'raw email content'
def test_create_source_item(db_session):
# Mock data
message_hash = b'test_hash_bytes' + bytes(28) # 32 bytes for SHA-256
account_tags = ["work", "important"]
raw_email_size = 1024
# Call function
source_item = create_source_item(
db_session=db_session,
message_hash=message_hash,
account_tags=account_tags,
raw_email_size=raw_email_size
)
# Verify the source item was created correctly
assert isinstance(source_item, SourceItem)
assert source_item.id is not None
assert source_item.modality == "mail"
assert source_item.sha256 == message_hash
assert source_item.tags == account_tags
assert source_item.byte_length == raw_email_size
assert source_item.mime_type == "message/rfc822"
assert source_item.embed_status == "RAW"
# Verify it was added to the session
db_session.flush()
fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
assert fetched_item is not None
assert fetched_item.sha256 == message_hash