diff --git a/.gitignore b/.gitignore index 28c5b36..4af5d5d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,15 @@ .DS_Store secrets/ acme.json -__pycache__/ \ No newline at end of file +__pycache__/ +*.egg-info/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw + +docker-compose.override.yml +docker/pgadmin \ No newline at end of file diff --git a/db/migrations/alembic.ini b/db/migrations/alembic.ini new file mode 100644 index 0000000..2f7b489 --- /dev/null +++ b/db/migrations/alembic.ini @@ -0,0 +1,107 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = db/migrations + +# template used to generate migration files +file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +version_locations = db/migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# URL of the database to connect to +sqlalchemy.url = driver://user:pass@localhost/dbname + +# enable compare_type (to detect column type changes) +# and compare_server_default (to detect changed in server default values) +compare_type = true +compare_server_default = true + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/db/migrations/env.py b/db/migrations/env.py new file mode 100644 index 0000000..796789b --- /dev/null +++ b/db/migrations/env.py @@ -0,0 +1,82 @@ +""" +Alembic environment configuration. +""" +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from memory.common import settings +from memory.common.db.models import Base + + +# this is the Alembic Config object +config = context.config + +# setup database URL from environment variables +config.set_main_option("sqlalchemy.url", settings.DB_URL) + +# Interpret the config file for Python logging +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") + + +def run_migrations_offline() -> None: + """ + Run migrations in 'offline' mode - creates SQL scripts. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """ + Run migrations in 'online' mode - directly to the database. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() \ No newline at end of file diff --git a/db/migrations/script.py.mako b/db/migrations/script.py.mako new file mode 100644 index 0000000..46200be --- /dev/null +++ b/db/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} \ No newline at end of file diff --git a/db/migrations/versions/20250427_171537_initial_structure.py b/db/migrations/versions/20250427_171537_initial_structure.py new file mode 100644 index 0000000..5456416 --- /dev/null +++ b/db/migrations/versions/20250427_171537_initial_structure.py @@ -0,0 +1,341 @@ +"""Initial structure + +Revision ID: a466a07360d5 +Revises: +Create Date: 2025-04-27 17:15:37.487616 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = "a466a07360d5" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute('CREATE EXTENSION IF NOT EXISTS pgcrypto') + + # Create enum type for github_item with IF NOT EXISTS + op.execute("DO $$ BEGIN CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); EXCEPTION WHEN duplicate_object THEN NULL; END $$;") + + op.create_table( + "email_accounts", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("email_address", sa.Text(), nullable=False), + sa.Column("imap_server", sa.Text(), nullable=False), + sa.Column("imap_port", sa.Integer(), server_default="993", nullable=False), + sa.Column("username", sa.Text(), nullable=False), + sa.Column("password", sa.Text(), nullable=False), + sa.Column("use_ssl", sa.Boolean(), server_default="true", nullable=False), + sa.Column("folders", sa.ARRAY(sa.Text()), server_default="{}", nullable=False), + sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False), + sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("active", sa.Boolean(), server_default="true", nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email_address"), + ) + op.create_table( + "rss_feeds", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("url", sa.Text(), nullable=False), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False), + sa.Column("last_checked_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("active", sa.Boolean(), server_default="true", nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("url"), + ) + op.create_table( + "source_item", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("modality", sa.Text(), nullable=False), + sa.Column("sha256", postgresql.BYTEA(), nullable=False), + sa.Column( + "inserted_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False), + sa.Column("lang", sa.Text(), nullable=True), + sa.Column("model_hash", sa.Text(), nullable=True), + sa.Column( + "vector_ids", sa.ARRAY(sa.Text()), server_default="{}", nullable=False + ), + sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False), + sa.Column("byte_length", sa.Integer(), nullable=True), + sa.Column("mime_type", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("sha256"), + ) + op.create_table( + "blog_post", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("url", sa.Text(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("published", sa.DateTime(timezone=True), nullable=True), + 
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("url"), + ) + op.create_table( + "book_doc", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("author", sa.Text(), nullable=True), + sa.Column("chapter", sa.Text(), nullable=True), + sa.Column("published", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "chat_message", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("platform", sa.Text(), nullable=True), + sa.Column("channel_id", sa.Text(), nullable=True), + sa.Column("author", sa.Text(), nullable=True), + sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("body_raw", sa.Text(), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "git_commit", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("repo_path", sa.Text(), nullable=True), + sa.Column("commit_sha", sa.Text(), nullable=True), + sa.Column("author_name", sa.Text(), nullable=True), + sa.Column("author_email", sa.Text(), nullable=True), + sa.Column("author_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("msg_raw", sa.Text(), nullable=True), + sa.Column("diff_summary", sa.Text(), nullable=True), + sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("commit_sha"), + ) + op.create_table( + "mail_message", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.Text(), nullable=True), + sa.Column("subject", sa.Text(), nullable=True), + sa.Column("sender", sa.Text(), nullable=True), + sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("body_raw", sa.Text(), nullable=True), + sa.Column( + "attachments", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("tsv", postgresql.TSVECTOR(), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("message_id"), + ) + op.create_table( + "misc_doc", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("path", sa.Text(), nullable=True), + sa.Column("mime_type", sa.Text(), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "photo", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("exif_lat", sa.Numeric(9, 6), nullable=True), + sa.Column("exif_lon", sa.Numeric(9, 6), nullable=True), + sa.Column("camera", sa.Text(), nullable=True), + sa.ForeignKeyConstraint(["source_id"], 
["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + # Add github_item table + op.create_table( + "github_item", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), + sa.Column("kind", sa.Text(), nullable=False), + sa.Column("repo_path", sa.Text(), nullable=False), + sa.Column("number", sa.Integer(), nullable=True), + sa.Column("parent_number", sa.Integer(), nullable=True), + sa.Column("commit_sha", sa.Text(), nullable=True), + sa.Column("state", sa.Text(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("body_raw", sa.Text(), nullable=True), + sa.Column("labels", sa.ARRAY(sa.Text()), nullable=True), + sa.Column("author", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("merged_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("diff_summary", sa.Text(), nullable=True), + sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + # Add constraint to github_item.kind + op.create_check_constraint( + "github_item_kind_check", + "github_item", + "kind IN ('issue', 'pr', 'comment', 'project_card')" + ) + + # Add missing constraint to source_item + op.create_check_constraint( + "source_item_embed_status_check", + "source_item", + "embed_status IN ('RAW','QUEUED','STORED','FAILED')" + ) + + # Create trigger function for vector_ids validation + op.execute(''' + CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty() + RETURNS TRIGGER LANGUAGE plpgsql AS $$ + BEGIN + IF NEW.embed_status = 'STORED' + AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN + RAISE EXCEPTION + USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED'; + END IF; + RETURN NEW; + END; + $$; + ''') + + # Create trigger + op.execute(''' + CREATE TRIGGER check_vector_ids + BEFORE UPDATE ON source_item + FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty(); + ''') + + # Create indexes for source_item + op.create_index('source_modality_idx', 'source_item', ['modality']) + op.create_index('source_status_idx', 'source_item', ['embed_status']) + op.create_index('source_tags_idx', 'source_item', ['tags'], postgresql_using='gin') + + # Create indexes for mail_message + op.create_index('mail_sent_idx', 'mail_message', ['sent_at']) + op.create_index('mail_recipients_idx', 'mail_message', ['recipients'], postgresql_using='gin') + op.create_index('mail_tsv_idx', 'mail_message', ['tsv'], postgresql_using='gin') + + # Create index for chat_message + op.create_index('chat_channel_idx', 'chat_message', ['platform', 'channel_id']) + + # Create indexes for git_commit + op.create_index('git_files_idx', 'git_commit', ['files_changed'], postgresql_using='gin') + op.create_index('git_date_idx', 'git_commit', ['author_date']) + + # Create index for photo + op.create_index('photo_taken_idx', 'photo', ['exif_taken_at']) + + # Create indexes for rss_feeds + op.create_index('rss_feeds_active_idx', 'rss_feeds', ['active', 'last_checked_at']) + op.create_index('rss_feeds_tags_idx', 'rss_feeds', ['tags'], postgresql_using='gin') + + # Create indexes for email_accounts + op.create_index('email_accounts_address_idx', 'email_accounts', ['email_address'], unique=True) + op.create_index('email_accounts_active_idx', 'email_accounts', 
['active', 'last_sync_at']) + op.create_index('email_accounts_tags_idx', 'email_accounts', ['tags'], postgresql_using='gin') + + # Create indexes for github_item + op.create_index('gh_repo_kind_idx', 'github_item', ['repo_path', 'kind']) + op.create_index('gh_issue_lookup_idx', 'github_item', ['repo_path', 'kind', 'number']) + op.create_index('gh_labels_idx', 'github_item', ['labels'], postgresql_using='gin') + + # Create add_tags helper function + op.execute(''' + CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[]) + RETURNS VOID LANGUAGE SQL AS $$ + UPDATE source_item + SET tags = + (SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags))) + WHERE id = p_source; + $$; + ''') + + +def downgrade() -> None: + # Drop indexes + op.drop_index('gh_tsv_idx', table_name='github_item') + op.drop_index('gh_labels_idx', table_name='github_item') + op.drop_index('gh_issue_lookup_idx', table_name='github_item') + op.drop_index('gh_repo_kind_idx', table_name='github_item') + op.drop_index('email_accounts_tags_idx', table_name='email_accounts') + op.drop_index('email_accounts_active_idx', table_name='email_accounts') + op.drop_index('email_accounts_address_idx', table_name='email_accounts') + op.drop_index('rss_feeds_tags_idx', table_name='rss_feeds') + op.drop_index('rss_feeds_active_idx', table_name='rss_feeds') + op.drop_index('photo_taken_idx', table_name='photo') + op.drop_index('git_date_idx', table_name='git_commit') + op.drop_index('git_files_idx', table_name='git_commit') + op.drop_index('chat_channel_idx', table_name='chat_message') + op.drop_index('mail_tsv_idx', table_name='mail_message') + op.drop_index('mail_recipients_idx', table_name='mail_message') + op.drop_index('mail_sent_idx', table_name='mail_message') + op.drop_index('source_tags_idx', table_name='source_item') + op.drop_index('source_status_idx', table_name='source_item') + op.drop_index('source_modality_idx', table_name='source_item') + + # Drop tables + op.drop_table("photo") + op.drop_table("misc_doc") + op.drop_table("mail_message") + op.drop_table("git_commit") + op.drop_table("chat_message") + op.drop_table("book_doc") + op.drop_table("blog_post") + op.drop_table("github_item") + op.drop_table("source_item") + op.drop_table("rss_feeds") + op.drop_table("email_accounts") + + # Drop triggers and functions + op.execute("DROP TRIGGER IF EXISTS check_vector_ids ON source_item") + op.execute("DROP FUNCTION IF EXISTS trg_vector_ids_not_empty()") + op.execute("DROP FUNCTION IF EXISTS add_tags(BIGINT, TEXT[])") + + # Drop enum type + op.execute("DROP TYPE IF EXISTS gh_item_kind") diff --git a/db/schema.sql b/db/schema.sql deleted file mode 100644 index 961a776..0000000 --- a/db/schema.sql +++ /dev/null @@ -1,229 +0,0 @@ -/*======================================================================== - Knowledge-Base schema – first-run script - --------------------------------------------------------------- - • PostgreSQL 15+ - • Creates every table, index, trigger and helper in one pass - • No ALTER statements or later migrations required - • Enable pgcrypto for UUID helpers (safe to re-run) - ========================================================================*/ - -------------------------------------------------------------------------------- --- 0. EXTENSIONS -------------------------------------------------------------------------------- -CREATE EXTENSION IF NOT EXISTS pgcrypto; -- gen_random_uuid(), crypt() - -------------------------------------------------------------------------------- --- 1. 
CANONICAL ARTEFACT TABLE (everything points here) -------------------------------------------------------------------------------- -CREATE TABLE source_item ( - id BIGSERIAL PRIMARY KEY, - modality TEXT NOT NULL, -- 'mail'|'chat'|... - sha256 BYTEA UNIQUE NOT NULL, -- 32-byte blob - inserted_at TIMESTAMPTZ DEFAULT NOW(), - tags TEXT[] NOT NULL DEFAULT '{}', -- flexible labels - lang TEXT, -- ISO-639-1 or NULL - model_hash TEXT, -- embedding model ver. - vector_ids TEXT[] NOT NULL DEFAULT '{}', -- 0-N Qdrant IDs - embed_status TEXT NOT NULL DEFAULT 'RAW' - CHECK (embed_status IN ('RAW','QUEUED','STORED','FAILED')), - byte_length INTEGER, -- original size - mime_type TEXT -); - -CREATE INDEX source_modality_idx ON source_item (modality); -CREATE INDEX source_status_idx ON source_item (embed_status); -CREATE INDEX source_tags_idx ON source_item USING GIN (tags); - --- 1.a Trigger – vector_ids must be present when status = STORED -CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty() -RETURNS TRIGGER LANGUAGE plpgsql AS $$ -BEGIN - IF NEW.embed_status = 'STORED' - AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN - RAISE EXCEPTION - USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED'; - END IF; - RETURN NEW; -END; -$$; -CREATE TRIGGER check_vector_ids -BEFORE UPDATE ON source_item -FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty(); - -------------------------------------------------------------------------------- --- 2. MAIL MESSAGES -------------------------------------------------------------------------------- -CREATE TABLE mail_message ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - message_id TEXT UNIQUE, - subject TEXT, - sender TEXT, - recipients TEXT[], - sent_at TIMESTAMPTZ, - body_raw TEXT, - attachments JSONB -); - -CREATE INDEX mail_sent_idx ON mail_message (sent_at); -CREATE INDEX mail_recipients_idx ON mail_message USING GIN (recipients); - -ALTER TABLE mail_message - ADD COLUMN tsv tsvector - GENERATED ALWAYS AS ( - to_tsvector('english', - coalesce(subject,'') || ' ' || coalesce(body_raw,''))) - STORED; -CREATE INDEX mail_tsv_idx ON mail_message USING GIN (tsv); - -------------------------------------------------------------------------------- --- 3. CHAT (Slack / Discord) -------------------------------------------------------------------------------- -CREATE TABLE chat_message ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - platform TEXT CHECK (platform IN ('slack','discord')), - channel_id TEXT, - author TEXT, - sent_at TIMESTAMPTZ, - body_raw TEXT -); -CREATE INDEX chat_channel_idx ON chat_message (platform, channel_id); - -------------------------------------------------------------------------------- --- 4. GIT COMMITS (local repos) -------------------------------------------------------------------------------- -CREATE TABLE git_commit ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - repo_path TEXT, - commit_sha TEXT UNIQUE, - author_name TEXT, - author_email TEXT, - author_date TIMESTAMPTZ, - msg_raw TEXT, - diff_summary TEXT, - files_changed TEXT[] -); -CREATE INDEX git_files_idx ON git_commit USING GIN (files_changed); -CREATE INDEX git_date_idx ON git_commit (author_date); - -------------------------------------------------------------------------------- --- 5. 
PHOTOS -------------------------------------------------------------------------------- -CREATE TABLE photo ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - file_path TEXT, - exif_taken_at TIMESTAMPTZ, - exif_lat NUMERIC(9,6), - exif_lon NUMERIC(9,6), - camera_make TEXT, - camera_model TEXT -); -CREATE INDEX photo_taken_idx ON photo (exif_taken_at); - -------------------------------------------------------------------------------- --- 6. BOOKS, BLOG POSTS, MISC DOCS -------------------------------------------------------------------------------- -CREATE TABLE book_doc ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - title TEXT, - author TEXT, - chapter TEXT, - published DATE -); - -CREATE TABLE blog_post ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - url TEXT UNIQUE, - title TEXT, - published TIMESTAMPTZ -); - -CREATE TABLE misc_doc ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - path TEXT, - mime_type TEXT -); - -------------------------------------------------------------------------------- --- 6.5 RSS FEEDS -------------------------------------------------------------------------------- -CREATE TABLE rss_feeds ( - id BIGSERIAL PRIMARY KEY, - url TEXT UNIQUE NOT NULL, - title TEXT, - description TEXT, - tags TEXT[] NOT NULL DEFAULT '{}', - last_checked_at TIMESTAMPTZ, - active BOOLEAN NOT NULL DEFAULT TRUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE INDEX rss_feeds_active_idx ON rss_feeds (active, last_checked_at); -CREATE INDEX rss_feeds_tags_idx ON rss_feeds USING GIN (tags); - -------------------------------------------------------------------------------- --- 7. GITHUB ITEMS (issues, PRs, comments, project cards) -------------------------------------------------------------------------------- -CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); - -CREATE TABLE github_item ( - id BIGSERIAL PRIMARY KEY, - source_id BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE, - - kind gh_item_kind NOT NULL, - repo_path TEXT NOT NULL, -- "owner/repo" - number INTEGER, -- issue/PR number (NULL for commit comment) - parent_number INTEGER, -- comment → its issue/PR - commit_sha TEXT, -- for commit comments - state TEXT, -- 'open'|'closed'|'merged' - title TEXT, - body_raw TEXT, - labels TEXT[], - author TEXT, - created_at TIMESTAMPTZ, - closed_at TIMESTAMPTZ, - merged_at TIMESTAMPTZ, - diff_summary TEXT, -- PR only - - payload JSONB -- extra GitHub fields -); - -CREATE INDEX gh_repo_kind_idx ON github_item (repo_path, kind); -CREATE INDEX gh_issue_lookup_idx ON github_item (repo_path, kind, number); -CREATE INDEX gh_labels_idx ON github_item USING GIN (labels); - -CREATE INDEX gh_tsv_idx ON github_item -WHERE kind IN ('issue','pr') -USING GIN (to_tsvector('english', - coalesce(title,'') || ' ' || coalesce(body_raw,''))); - -------------------------------------------------------------------------------- --- 8. 
HELPER FUNCTION – add tags -------------------------------------------------------------------------------- -CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[]) -RETURNS VOID LANGUAGE SQL AS $$ -UPDATE source_item - SET tags = - (SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags))) - WHERE id = p_source; -$$; - -------------------------------------------------------------------------------- --- 9. (optional) PARTITION STUBS – create per-year partitions later -------------------------------------------------------------------------------- -/* --- example: -CREATE TABLE mail_message_2026 PARTITION OF mail_message - FOR VALUES FROM ('2026-01-01') TO ('2027-01-01'); -*/ - --- ========================================================================= --- Schema creation complete --- ========================================================================= \ No newline at end of file diff --git a/dev.sh b/dev.sh new file mode 100755 index 0000000..582e577 --- /dev/null +++ b/dev.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +set -eo pipefail + +# Colors for output +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +echo -e "${GREEN}Starting development environment for Memory Knowledge Base...${NC}" + +# Get the directory of the script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Create a temporary docker-compose override file to expose PostgreSQL +echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}" +if [ ! -f docker-compose.override.yml ]; then + cat > docker-compose.override.yml << EOL +version: "3.9" +services: + postgres: + ports: + - "5432:5432" +EOL +fi + +# Start the containers +echo -e "${GREEN}Starting docker containers...${NC}" +docker-compose up -d postgres rabbitmq qdrant + +# Wait for PostgreSQL to be ready +echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}" +for i in {1..30}; do + if docker-compose exec postgres pg_isready -U kb > /dev/null 2>&1; then + echo -e "${GREEN}PostgreSQL is ready!${NC}" + break + fi + echo -n "." + sleep 1 +done + +# Initialize the database if needed +echo -e "${YELLOW}Checking if database needs initialization...${NC}" +if ! 
docker-compose exec postgres psql -U kb -d kb -c "SELECT 1 FROM information_schema.tables WHERE table_name = 'source_item'" | grep -q 1; then + echo -e "${GREEN}Initializing database from schema.sql...${NC}" + docker-compose exec postgres psql -U kb -d kb -f /docker-entrypoint-initdb.d/schema.sql +else + echo -e "${GREEN}Database already initialized.${NC}" +fi + +echo -e "${GREEN}Development environment is ready!${NC}" +echo -e "${YELLOW}PostgreSQL is available at localhost:5432${NC}" +echo -e "${YELLOW}Username: kb${NC}" +echo -e "${YELLOW}Password: (check secrets/postgres_password.txt)${NC}" +echo -e "${YELLOW}Database: kb${NC}" +echo "" +echo -e "${GREEN}To stop the environment, run:${NC}" +echo -e "${YELLOW}docker-compose down${NC}" \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index d59bd76..65b0675 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -61,7 +61,6 @@ services: secrets: [postgres_password] volumes: - db_data:/var/lib/postgresql/data:rw - - ./db:/docker-entrypoint-initdb.d:ro healthcheck: test: ["CMD-SHELL", "pg_isready -U kb"] interval: 10s diff --git a/requirements-common.txt b/requirements-common.txt index 965c2ba..7f78901 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,3 +1,5 @@ sqlalchemy==2.0.30 psycopg2-binary==2.9.9 -pydantic==2.7.1 \ No newline at end of file +pydantic==2.7.1 +alembic==1.13.1 +python-dotenv==1.1.0 \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..4220417 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +pytest==7.4.4 +pytest-cov==4.1.0 +black==23.12.1 +mypy==1.8.0 +isort==5.13.2 \ No newline at end of file diff --git a/setup.py b/setup.py index 7965157..1c6ccba 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ def read_requirements(filename: str) -> list[str]: common_requires = read_requirements('requirements-common.txt') api_requires = read_requirements('requirements-api.txt') workers_requires = read_requirements('requirements-workers.txt') +dev_requires = read_requirements('requirements-dev.txt') setup( name="memory", @@ -27,5 +28,7 @@ setup( "api": api_requires + common_requires, "workers": workers_requires + common_requires, "common": common_requires, + "dev": dev_requires, + "all": api_requires + workers_requires + common_requires + dev_requires, }, ) \ No newline at end of file diff --git a/src/memory/common/db/connection.py b/src/memory/common/db/connection.py index b21ea44..92e4b7d 100644 --- a/src/memory/common/db/connection.py +++ b/src/memory/common/db/connection.py @@ -29,4 +29,9 @@ def get_scoped_session(): """Create a thread-local scoped session factory""" engine = get_engine() session_factory = sessionmaker(bind=engine) - return scoped_session(session_factory) \ No newline at end of file + return scoped_session(session_factory) + + +def make_session(): + """Return a thread-local session usable as a context manager.""" + return get_scoped_session()() diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 20af8bb..878fadb 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -2,8 +2,8 @@ Database models for the knowledge base system.
""" from sqlalchemy import ( - Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean, Float, - ARRAY, func + Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean, + ARRAY, func, Numeric, CheckConstraint, Index ) from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR from sqlalchemy.ext.declarative import declarative_base @@ -26,6 +26,14 @@ class SourceItem(Base): embed_status = Column(Text, nullable=False, server_default='RAW') byte_length = Column(Integer) mime_type = Column(Text) + + # Add table-level constraint and indexes + __table_args__ = ( + CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), + Index('source_modality_idx', 'modality'), + Index('source_status_idx', 'embed_status'), + Index('source_tags_idx', 'tags', postgresql_using='gin'), + ) class MailMessage(Base): @@ -41,6 +49,13 @@ class MailMessage(Base): body_raw = Column(Text) attachments = Column(JSONB) tsv = Column(TSVECTOR) + + # Add indexes + __table_args__ = ( + Index('mail_sent_idx', 'sent_at'), + Index('mail_recipients_idx', 'recipients', postgresql_using='gin'), + Index('mail_tsv_idx', 'tsv', postgresql_using='gin'), + ) class ChatMessage(Base): @@ -53,6 +68,11 @@ class ChatMessage(Base): author = Column(Text) sent_at = Column(DateTime(timezone=True)) body_raw = Column(Text) + + # Add index + __table_args__ = ( + Index('chat_channel_idx', 'platform', 'channel_id'), + ) class GitCommit(Base): @@ -68,6 +88,12 @@ class GitCommit(Base): msg_raw = Column(Text) diff_summary = Column(Text) files_changed = Column(ARRAY(Text)) + + # Add indexes + __table_args__ = ( + Index('git_files_idx', 'files_changed', postgresql_using='gin'), + Index('git_date_idx', 'author_date'), + ) class Photo(Base): @@ -77,10 +103,14 @@ class Photo(Base): source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False) file_path = Column(Text) exif_taken_at = Column(DateTime(timezone=True)) - exif_lat = Column(Float) - exif_lon = Column(Float) - camera_make = Column(Text) - camera_model = Column(Text) + exif_lat = Column(Numeric(9, 6)) + exif_lon = Column(Numeric(9, 6)) + camera = Column(Text) + + # Add index + __table_args__ = ( + Index('photo_taken_idx', 'exif_taken_at'), + ) class BookDoc(Base): @@ -91,7 +121,7 @@ class BookDoc(Base): title = Column(Text) author = Column(Text) chapter = Column(Text) - published = Column(DateTime) + published = Column(DateTime(timezone=True)) class BlogPost(Base): @@ -124,4 +154,67 @@ class RssFeed(Base): last_checked_at = Column(DateTime(timezone=True)) active = Column(Boolean, nullable=False, server_default='true') created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) - updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) \ No newline at end of file + updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + # Add indexes + __table_args__ = ( + Index('rss_feeds_active_idx', 'active', 'last_checked_at'), + Index('rss_feeds_tags_idx', 'tags', postgresql_using='gin'), + ) + + +class EmailAccount(Base): + __tablename__ = 'email_accounts' + + id = Column(BigInteger, primary_key=True) + name = Column(Text, nullable=False) + email_address = Column(Text, nullable=False, unique=True) + imap_server = Column(Text, nullable=False) + imap_port = Column(Integer, nullable=False, server_default='993') + username = Column(Text, nullable=False) + password = Column(Text, nullable=False) + use_ssl = Column(Boolean, nullable=False, 
server_default='true') + folders = Column(ARRAY(Text), nullable=False, server_default='{}') + tags = Column(ARRAY(Text), nullable=False, server_default='{}') + last_sync_at = Column(DateTime(timezone=True)) + active = Column(Boolean, nullable=False, server_default='true') + created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + # Add indexes + __table_args__ = ( + Index('email_accounts_address_idx', 'email_address', unique=True), + Index('email_accounts_active_idx', 'active', 'last_sync_at'), + Index('email_accounts_tags_idx', 'tags', postgresql_using='gin'), + ) + + +class GithubItem(Base): + __tablename__ = 'github_item' + + id = Column(BigInteger, primary_key=True) + source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False) + + kind = Column(Text, nullable=False) + repo_path = Column(Text, nullable=False) + number = Column(Integer) + parent_number = Column(Integer) + commit_sha = Column(Text) + state = Column(Text) + title = Column(Text) + body_raw = Column(Text) + labels = Column(ARRAY(Text)) + author = Column(Text) + created_at = Column(DateTime(timezone=True)) + closed_at = Column(DateTime(timezone=True)) + merged_at = Column(DateTime(timezone=True)) + diff_summary = Column(Text) + + payload = Column(JSONB) + + __table_args__ = ( + CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"), + Index('gh_repo_kind_idx', 'repo_path', 'kind'), + Index('gh_issue_lookup_idx', 'repo_path', 'kind', 'number'), + Index('gh_labels_idx', 'labels', postgresql_using='gin'), + ) \ No newline at end of file diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py new file mode 100644 index 0000000..d06e8c6 --- /dev/null +++ b/src/memory/common/settings.py @@ -0,0 +1,16 @@ +import os +from dotenv import load_dotenv + +load_dotenv() + + +DB_USER = os.getenv("DB_USER", "kb") +DB_PASSWORD = os.getenv("DB_PASSWORD", "kb") +DB_HOST = os.getenv("DB_HOST", "postgres") +DB_PORT = os.getenv("DB_PORT", "5432") +DB_NAME = os.getenv("DB_NAME", "kb") + +def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME): + return f"postgresql://{user}:{password}@{host}:{port}/{db}" + +DB_URL = os.getenv("DATABASE_URL", make_db_url()) \ No newline at end of file diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py new file mode 100644 index 0000000..df858ee --- /dev/null +++ b/src/memory/workers/email.py @@ -0,0 +1,362 @@ +import email +import hashlib +import imaplib +import logging +import re +from contextlib import contextmanager +from datetime import datetime +from email.utils import parsedate_to_datetime + +from sqlalchemy.orm import Session + +from memory.common.db.models import EmailAccount, MailMessage, SourceItem + + +logger = logging.getLogger(__name__) + + +def extract_recipients(msg: email.message.Message) -> list[str]: + """ + Extract email recipients from message headers. + + Args: + msg: Email message object + + Returns: + List of recipient email addresses + """ + return [ + recipient + for field in ["To", "Cc", "Bcc"] + if (field_value := msg.get(field, "")) + for r in field_value.split(",") + if (recipient := r.strip()) + ] + + +def extract_date(msg: email.message.Message) -> datetime | None: + """ + Parse date from email header. 
+ + Args: + msg: Email message object + + Returns: + Parsed datetime or None if parsing failed + """ + if date_str := msg.get("Date"): + try: + return parsedate_to_datetime(date_str) + except Exception: + logger.warning(f"Could not parse date: {date_str}") + return None + + +def extract_body(msg: email.message.Message) -> str: + """ + Extract plain text body from email message. + + Args: + msg: Email message object + + Returns: + Plain text body content + """ + body = "" + + if not msg.is_multipart(): + try: + return msg.get_payload(decode=True).decode(errors='replace') + except Exception as e: + logger.error(f"Error decoding message body: {str(e)}") + return "" + + for part in msg.walk(): + content_type = part.get_content_type() + content_disposition = str(part.get("Content-Disposition", "")) + + if content_type == "text/plain" and "attachment" not in content_disposition: + try: + body += part.get_payload(decode=True).decode(errors='replace') + "\n" + except Exception as e: + logger.error(f"Error decoding message part: {str(e)}") + return body + + +def extract_attachments(msg: email.message.Message) -> list[dict]: + """ + Extract attachment metadata from email. + + Args: + msg: Email message object + + Returns: + List of attachment metadata dicts + """ + if not msg.is_multipart(): + return [] + + attachments = [] + for part in msg.walk(): + content_disposition = part.get("Content-Disposition", "") + if "attachment" not in content_disposition: + continue + + if filename := part.get_filename(): + attachments.append({ + "filename": filename, + "content_type": part.get_content_type(), + "size": len(part.get_payload(decode=True)) + }) + + return attachments + + +def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes: + """ + Compute a SHA-256 hash of message content. + + Args: + msg_id: Message ID + subject: Email subject + sender: Sender email + body: Message body + + Returns: + SHA-256 hash as bytes + """ + hash_content = (msg_id + subject + sender + body).encode() + return hashlib.sha256(hash_content).digest() + + +def parse_email_message(raw_email: str) -> dict: + """ + Parse raw email into structured data. + + Args: + raw_email: Raw email content as string + + Returns: + Dict with parsed email data + """ + msg = email.message_from_string(raw_email) + + return { + "message_id": msg.get("Message-ID", ""), + "subject": msg.get("Subject", ""), + "sender": msg.get("From", ""), + "recipients": extract_recipients(msg), + "sent_at": extract_date(msg), + "body": extract_body(msg), + "attachments": extract_attachments(msg) + } + + +def create_source_item( + db_session: Session, + message_hash: bytes, + account_tags: list[str], + raw_email_size: int, +) -> SourceItem: + """ + Create a new source item record. + + Args: + db_session: Database session + message_hash: SHA-256 hash of message + account_tags: Tags from the email account + raw_email_size: Size of raw email in bytes + + Returns: + Newly created SourceItem + """ + source_item = SourceItem( + modality="mail", + sha256=message_hash, + tags=account_tags, + byte_length=raw_email_size, + mime_type="message/rfc822", + embed_status="RAW" + ) + db_session.add(source_item) + db_session.flush() + return source_item + + +def create_mail_message( + db_session: Session, + source_id: int, + parsed_email: dict, + folder: str, +) -> MailMessage: + """ + Create a new mail message record. 
+ + Args: + db_session: Database session + source_id: ID of the SourceItem + parsed_email: Parsed email data + folder: IMAP folder name + + Returns: + Newly created MailMessage + """ + mail_message = MailMessage( + source_id=source_id, + message_id=parsed_email["message_id"], + subject=parsed_email["subject"], + sender=parsed_email["sender"], + recipients=parsed_email["recipients"], + sent_at=parsed_email["sent_at"], + body_raw=parsed_email["body"], + attachments={"items": parsed_email["attachments"], "folder": folder} + ) + db_session.add(mail_message) + return mail_message + + +def check_message_exists(db_session: Session, message_id: str, message_hash: bytes) -> bool: + """ + Check if a message already exists in the database. + + Args: + db_session: Database session + message_id: Email message ID + message_hash: SHA-256 hash of message + + Returns: + True if message exists, False otherwise + """ + return ( + # Check by message_id first (faster) + message_id and db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first() + # Then check by message_hash + or db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() is not None + ) + + +def extract_email_uid(msg_data: bytes) -> tuple[str, str]: + """ + Extract the UID and raw email data from the message data. + """ + uid_pattern = re.compile(r'UID (\d+)') + uid_match = uid_pattern.search(msg_data[0][0].decode('utf-8', errors='replace')) + uid = uid_match.group(1) if uid_match else None + raw_email = msg_data[0][1] + return uid, raw_email + + +def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> tuple[str, bytes] | None: + try: + status, msg_data = conn.fetch(uid, '(UID RFC822)') + if status != 'OK' or not msg_data or not msg_data[0]: + logger.error(f"Error fetching message {uid}") + return None + + return extract_email_uid(msg_data) + except Exception as e: + logger.error(f"Error processing message {uid}: {str(e)}") + return None + + +def fetch_email_since( + conn: imaplib.IMAP4_SSL, + folder: str, + since_date: datetime +) -> list[tuple[str, bytes]]: + """ + Fetch emails from a folder since a given date. + + Args: + conn: IMAP connection + folder: Folder name to select + since_date: Fetch emails since this date + + Returns: + List of tuples with (uid, raw_email) + """ + try: + status, counts = conn.select(folder) + if status != 'OK': + logger.error(f"Error selecting folder {folder}: {counts}") + return [] + + date_str = since_date.strftime("%d-%b-%Y") + + status, data = conn.search(None, f'(SINCE "{date_str}")') + if status != 'OK': + logger.error(f"Error searching folder {folder}: {data}") + return [] + except Exception as e: + logger.error(f"Error in fetch_email_since for folder {folder}: {str(e)}") + return [] + + if not data or not data[0]: + return [] + + return [email for uid in data[0].split() if (email := fetch_email(conn, uid))] + + +def process_folder( + conn: imaplib.IMAP4_SSL, + folder: str, + account: EmailAccount, + since_date: datetime +) -> dict: + """ + Process a single folder from an email account. 
+ + Args: + conn: Active IMAP connection + folder: Folder name to process + account: Email account configuration + since_date: Only fetch messages newer than this date + + Returns: + Stats dictionary for the folder + """ + new_messages, errors = 0, 0 + + try: + emails = fetch_email_since(conn, folder, since_date) + + for uid, raw_email in emails: + try: + task = process_message.delay( + account_id=account.id, + message_id=uid, + folder=folder, + raw_email=raw_email.decode('utf-8', errors='replace') + ) + if task: + new_messages += 1 + except Exception as e: + logger.error(f"Error queuing message {uid}: {str(e)}") + errors += 1 + + except Exception as e: + logger.error(f"Error processing folder {folder}: {str(e)}") + errors += 1 + + return { + "messages_found": len(emails), + "new_messages": new_messages, + "errors": errors + } + + +@contextmanager +def imap_connection(account: EmailAccount) -> imaplib.IMAP4_SSL: + conn = imaplib.IMAP4_SSL( + host=account.imap_server, + port=account.imap_port + ) + try: + conn.login(account.username, account.password) + yield conn + finally: + # Always try to logout and close the connection + try: + conn.logout() + except Exception as e: + logger.error(f"Error logging out from {account.imap_server}: {str(e)}") diff --git a/src/memory/workers/tasks/__init__.py b/src/memory/workers/tasks/__init__.py index 40f1244..1b1f3e8 100644 --- a/src/memory/workers/tasks/__init__.py +++ b/src/memory/workers/tasks/__init__.py @@ -1,4 +1,4 @@ """ Import sub-modules so Celery can register their @app.task decorators. """ -from memory.workers.tasks import text, photo, ocr, git, rss, docs # noqa \ No newline at end of file +from memory.workers.tasks import text, photo, ocr, git, rss, docs, email # noqa \ No newline at end of file diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py new file mode 100644 index 0000000..0de775c --- /dev/null +++ b/src/memory/workers/tasks/email.py @@ -0,0 +1,137 @@ +import logging +from datetime import datetime + +from memory.common.db.connection import make_session +from memory.common.db.models import EmailAccount +from memory.workers.celery_app import app +from memory.workers.email import ( + check_message_exists, + compute_message_hash, + create_mail_message, + create_source_item, + imap_connection, + parse_email_message, + process_folder, +) + + +logger = logging.getLogger(__name__) + + +@app.task(name="memory.email.process_message") +def process_message( + account_id: int, message_id: str, folder: str, raw_email: str, +) -> int | None: + """ + Process a single email message and store it in the database. 
+ + Args: + account_id: ID of the EmailAccount + message_id: UID of the message on the server + folder: Folder name where the message is stored + raw_email: Raw email content as string + + Returns: + source_id if successful, None otherwise + """ + with make_session() as db: + account = db.query(EmailAccount).get(account_id) + if not account: + logger.error(f"Account {account_id} not found") + return None + + parsed_email = parse_email_message(raw_email) + + # Use server-provided message ID if missing + if not parsed_email["message_id"]: + parsed_email["message_id"] = f"generated-{message_id}" + + message_hash = compute_message_hash( + parsed_email["message_id"], + parsed_email["subject"], + parsed_email["sender"], + parsed_email["body"] + ) + + if check_message_exists(db, parsed_email["message_id"], message_hash): + logger.debug(f"Message {parsed_email['message_id']} already exists in database") + return None + + source_item = create_source_item(db, message_hash, account.tags, len(raw_email)) + + create_mail_message(db, source_item.id, parsed_email, folder) + + db.commit() + + # TODO: Queue for embedding once that's implemented + + return source_item.id + + +@app.task(name="memory.email.sync_account") +def sync_account(account_id: int) -> dict: + """ + Synchronize emails from a specific account. + + Args: + account_id: ID of the EmailAccount to sync + + Returns: + dict with stats about the sync operation + """ + with make_session() as db: + account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first() + if not account or not account.active: + logger.warning(f"Account {account_id} not found or inactive") + return {"error": "Account not found or inactive"} + + folders_to_process = account.folders or ["INBOX"] + since_date = account.last_sync_at or datetime(1970, 1, 1) + + messages_found = 0 + new_messages = 0 + errors = 0 + + try: + with imap_connection(account) as conn: + for folder in folders_to_process: + folder_stats = process_folder(conn, folder, account, since_date) + + messages_found += folder_stats["messages_found"] + new_messages += folder_stats["new_messages"] + errors += folder_stats["errors"] + + account.last_sync_at = datetime.now() + db.commit() + except Exception as e: + logger.error(f"Error connecting to server {account.imap_server}: {str(e)}") + return {"error": str(e)} + + return { + "account": account.email_address, + "folders_processed": len(folders_to_process), + "messages_found": messages_found, + "new_messages": new_messages, + "errors": errors + } + + +@app.task(name="memory.email.sync_all_accounts") +def sync_all_accounts() -> list[dict]: + """ + Synchronize all active email accounts. 
+ + Returns: + List of task IDs that were scheduled + """ + with make_session() as db: + active_accounts = db.query(EmailAccount).filter(EmailAccount.active).all() + + return [ + { + "account_id": account.id, + "email": account.email_address, + "task_id": sync_account.delay(account.id).id + } + for account in active_accounts + ] \ No newline at end of file diff --git a/tests/memory/workers/tasks/conftest.py b/tests/memory/workers/tasks/conftest.py new file mode 100644 index 0000000..feaed21 --- /dev/null +++ b/tests/memory/workers/tasks/conftest.py @@ -0,0 +1,128 @@ +import os +import subprocess +import uuid +from pathlib import Path + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker + +from memory.common import settings + + +def get_test_db_name() -> str: + return f"test_db_{uuid.uuid4().hex[:8]}" + + +def create_test_database(test_db_name: str) -> str: + """ + Create a test database with a unique name. + + Args: + test_db_name: Name for the test database + + Returns: + URL to the test database + """ + admin_engine = create_engine(settings.DB_URL) + + # Create a new database + with admin_engine.connect() as conn: + conn.execute(text("COMMIT")) # Close any open transaction + conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}")) + conn.execute(text(f"CREATE DATABASE {test_db_name}")) + + admin_engine.dispose() + + return settings.make_db_url(db=test_db_name) + + +def drop_test_database(test_db_name: str) -> None: + """ + Drop the test database. + + Args: + test_db_name: Name of the test database to drop + """ + admin_engine = create_engine(settings.DB_URL) + + with admin_engine.connect() as conn: + conn.execute(text("COMMIT")) # Close any open transaction + conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}")) + + +def run_alembic_migrations(db_name: str) -> None: + """Run all Alembic migrations on the test database.""" + project_root = Path(__file__).parent.parent.parent.parent.parent + alembic_ini = project_root / "db" / "migrations" / "alembic.ini" + + subprocess.run( + ["alembic", "-c", str(alembic_ini), "upgrade", "head"], + env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)}, + check=True, + capture_output=True, + ) + + +@pytest.fixture +def test_db(): + """ + Create a test database, run migrations, and clean up afterwards. + + Returns: + The URL to the test database + """ + test_db_name = get_test_db_name() + + # Create test database + test_db_url = create_test_database(test_db_name) + + try: + run_alembic_migrations(test_db_name) + + # Return the URL to the test database + yield test_db_url + finally: + # Clean up - drop the test database + drop_test_database(test_db_name) + + +@pytest.fixture +def db_engine(test_db): + """ + Create a SQLAlchemy engine connected to the test database. + + Args: + test_db: URL to the test database (from the test_db fixture) + + Returns: + SQLAlchemy engine + """ + engine = create_engine(test_db) + yield engine + engine.dispose() + + +@pytest.fixture +def db_session(db_engine): + """ + Create a new database session for a test.
+ + Args: + db_engine: SQLAlchemy engine (from the db_engine fixture) + + Returns: + SQLAlchemy session + """ + # Create a new sessionmaker + SessionLocal = sessionmaker(bind=db_engine, autocommit=False, autoflush=False) + + # Create a new session + session = SessionLocal() + + try: + yield session + finally: + # Close and rollback the session after the test is done + session.rollback() + session.close() diff --git a/tests/memory/workers/tasks/test_email.py b/tests/memory/workers/tasks/test_email.py new file mode 100644 index 0000000..0f52083 --- /dev/null +++ b/tests/memory/workers/tasks/test_email.py @@ -0,0 +1,325 @@ +import email +import email.mime.multipart +import email.mime.text +import email.mime.base +from datetime import datetime +from email.utils import formatdate +from unittest.mock import ANY +import pytest + +from memory.common.db.models import SourceItem +from memory.workers.email import ( + compute_message_hash, + create_source_item, + extract_attachments, + extract_body, + extract_date, + extract_email_uid, + extract_recipients, + parse_email_message, +) + + +# Use a simple counter to generate unique message IDs without calling make_msgid +_msg_id_counter = 0 + + +def _generate_test_message_id(): + """Generate a simple message ID for testing without expensive calls""" + global _msg_id_counter + _msg_id_counter += 1 + return f"" + + +def create_email_message( + subject="Test Subject", + from_addr="sender@example.com", + to_addrs="recipient@example.com", + cc_addrs=None, + bcc_addrs=None, + date=None, + body="Test body content", + attachments=None, + multipart=True, + message_id=None, +): + """Helper function to create email.message.Message objects for testing""" + if multipart: + msg = email.mime.multipart.MIMEMultipart() + msg.attach(email.mime.text.MIMEText(body)) + + if attachments: + for attachment in attachments: + attachment_part = email.mime.base.MIMEBase("application", "octet-stream") + attachment_part.set_payload(attachment["content"]) + attachment_part.add_header( + "Content-Disposition", + f"attachment; filename={attachment['filename']}" + ) + msg.attach(attachment_part) + else: + msg = email.mime.text.MIMEText(body) + + msg["Subject"] = subject + msg["From"] = from_addr + msg["To"] = to_addrs + + if cc_addrs: + msg["Cc"] = cc_addrs + if bcc_addrs: + msg["Bcc"] = bcc_addrs + if date: + msg["Date"] = formatdate(float(date.timestamp())) + if message_id: + msg["Message-ID"] = message_id + else: + msg["Message-ID"] = _generate_test_message_id() + + return msg + + +@pytest.mark.parametrize( + "to_addr, cc_addr, bcc_addr, expected", + [ + # Single recipient in To field + ( + "recipient@example.com", + None, + None, + ["recipient@example.com"] + ), + # Multiple recipients in To field + ( + "recipient1@example.com, recipient2@example.com", + None, + None, + ["recipient1@example.com", "recipient2@example.com"] + ), + # To, Cc fields + ( + "recipient@example.com", + "cc@example.com", + None, + ["recipient@example.com", "cc@example.com"] + ), + # To, Cc, Bcc fields + ( + "recipient@example.com", + "cc@example.com", + "bcc@example.com", + ["recipient@example.com", "cc@example.com", "bcc@example.com"] + ), + # Empty fields + ( + "", + "", + "", + [] + ), + ] +) +def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected): + msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr) + assert sorted(extract_recipients(msg)) == sorted(expected) + + +def test_extract_date_missing(): + msg = create_email_message(date=None) + assert extract_date(msg) 
+
+
+@pytest.mark.parametrize(
+    "date_str",
+    [
+        "Invalid Date Format",
+        "2023-01-01",  # ISO format but not RFC compliant
+        "Monday, Jan 1, 2023",  # Descriptive but not RFC compliant
+        "01/01/2023",  # Common format but not RFC compliant
+        "",  # Empty string
+    ]
+)
+def test_extract_date_invalid_formats(date_str):
+    msg = create_email_message()
+    msg["Date"] = date_str
+    assert extract_date(msg) is None
+
+
+@pytest.mark.parametrize(
+    "date_str",
+    [
+        "Mon, 01 Jan 2023 12:00:00 +0000",  # RFC 5322 format
+        "01 Jan 2023 12:00:00 +0000",  # RFC 822 format
+        "Mon, 01 Jan 2023 12:00:00 GMT",  # With timezone name
+    ]
+)
+def test_extract_date(date_str):
+    msg = create_email_message()
+    msg["Date"] = date_str
+    result = extract_date(msg)
+
+    assert result is not None
+    assert result.year == 2023
+    assert result.month == 1
+    assert result.day == 1
+
+
+@pytest.mark.parametrize('multipart', [True, False])
+def test_extract_body_text_plain(multipart):
+    body_content = "This is a test email body"
+    msg = create_email_message(body=body_content, multipart=multipart)
+    extracted = extract_body(msg)
+
+    # Strip newlines for comparison since multipart emails often add them
+    assert extracted.strip() == body_content.strip()
+
+
+def test_extract_body_with_attachments():
+    body_content = "This is a test email body"
+    attachments = [
+        {"filename": "test.txt", "content": b"attachment content"}
+    ]
+    msg = create_email_message(body=body_content, attachments=attachments)
+    assert body_content in extract_body(msg)
+
+
+def test_extract_attachments_none():
+    msg = create_email_message(multipart=True)
+    assert extract_attachments(msg) == []
+
+
+def test_extract_attachments_with_files():
+    attachments = [
+        {"filename": "test1.txt", "content": b"content1"},
+        {"filename": "test2.pdf", "content": b"content2"}
+    ]
+    msg = create_email_message(attachments=attachments)
+
+    result = extract_attachments(msg)
+    assert len(result) == 2
+    assert result[0]["filename"] == "test1.txt"
+    assert result[1]["filename"] == "test2.pdf"
+
+
+def test_extract_attachments_non_multipart():
+    msg = create_email_message(multipart=False)
+    assert extract_attachments(msg) == []
+
+
+@pytest.mark.parametrize(
+    "msg_id, subject, sender, body, expected",
+    [
+        (
+            "",
+            "Test Subject",
+            "sender@example.com",
+            "Test body",
+            b"\xf2\xbd"  # First two bytes of the actual hash
+        ),
+        (
+            "",
+            "Test Subject",
+            "sender@example.com",
+            "Test body",
+            b"\xa4\x15"  # Will be different from the first hash
+        ),
+    ]
+)
+def test_compute_message_hash(msg_id, subject, sender, body, expected):
+    result = compute_message_hash(msg_id, subject, sender, body)
+
+    # Verify it's bytes and correct length for SHA-256 (32 bytes)
+    assert isinstance(result, bytes)
+    assert len(result) == 32
+
+    # Verify first two bytes match expected
+    assert result[:2] == expected
+
+
+def test_hash_consistency():
+    args = ("", "Test Subject", "sender@example.com", "Test body")
+    assert compute_message_hash(*args) == compute_message_hash(*args)
+
+
+def test_parse_simple_email():
+    test_date = datetime(2023, 1, 1, 12, 0, 0)
+    msg_id = ""
+    msg = create_email_message(
+        subject="Test Subject",
+        from_addr="sender@example.com",
+        to_addrs="recipient@example.com",
+        date=test_date,
+        body="Test body content",
+        message_id=msg_id
+    )
+
+    result = parse_email_message(msg.as_string())
+
+    assert result == {
+        "message_id": msg_id,
+        "subject": "Test Subject",
+        "sender": "sender@example.com",
+        "recipients": ["recipient@example.com"],
+        "body": "Test body content\n",
+        "attachments": [],
+        "sent_at": ANY,
+    }
+    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400
+
+
+def test_parse_email_with_attachments():
+    attachments = [
+        {"filename": "test.txt", "content": b"attachment content"}
+    ]
+    msg = create_email_message(attachments=attachments)
+
+    result = parse_email_message(msg.as_string())
+
+    assert len(result["attachments"]) == 1
+    assert result["attachments"][0]["filename"] == "test.txt"
+
+
+def test_extract_email_uid_valid():
+    msg_data = [(b'1 (UID 12345 RFC822 {1234}', b'raw email content')]
+    uid, raw_email = extract_email_uid(msg_data)
+
+    assert uid == "12345"
+    assert raw_email == b'raw email content'
+
+
+def test_extract_email_uid_no_match():
+    msg_data = [(b'1 (RFC822 {1234}', b'raw email content')]
+    uid, raw_email = extract_email_uid(msg_data)
+
+    assert uid is None
+    assert raw_email == b'raw email content'
+
+
+def test_create_source_item(db_session):
+    # Mock data
+    message_hash = b'test_hash_bytes' + bytes(17)  # pad to 32 bytes, the length of a SHA-256 digest
+    account_tags = ["work", "important"]
+    raw_email_size = 1024
+
+    # Call function
+    source_item = create_source_item(
+        db_session=db_session,
+        message_hash=message_hash,
+        account_tags=account_tags,
+        raw_email_size=raw_email_size
+    )
+
+    # Verify the source item was created correctly
+    assert isinstance(source_item, SourceItem)
+    assert source_item.id is not None
+    assert source_item.modality == "mail"
+    assert source_item.sha256 == message_hash
+    assert source_item.tags == account_tags
+    assert source_item.byte_length == raw_email_size
+    assert source_item.mime_type == "message/rfc822"
+    assert source_item.embed_status == "RAW"
+
+    # Verify it was added to the session
+    db_session.flush()
+    fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
+    assert fetched_item is not None
+    assert fetched_item.sha256 == message_hash