mirror of https://github.com/mruwnik/memory.git
synced 2025-06-08 05:14:43 +02:00

alembic + tests

parent a003ada9b7
commit d1cac9ffd9
13  .gitignore  (vendored)
@@ -2,4 +2,15 @@
.DS_Store
secrets/
acme.json
__pycache__/
*.egg-info/
*.pyc
*.pyo
*.pyd
*.pyw
*.pyz
*.pywz
*.pyzw

docker-compose.override.yml
docker/pgadmin
107  db/migrations/alembic.ini  (new file)
@@ -0,0 +1,107 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = db/migrations

# template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
version_locations = db/migrations/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# URL of the database to connect to
sqlalchemy.url = driver://user:pass@localhost/dbname

# enable compare_type (to detect column type changes)
# and compare_server_default (to detect changes in server default values)
compare_type = true
compare_server_default = true

[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
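With db/migrations/alembic.ini in place, pending migrations are normally applied from the repo root with the Alembic CLI (alembic -c db/migrations/alembic.ini upgrade head). The same thing can be done programmatically; a minimal sketch, assuming the memory package is installed so env.py can import memory.common.settings:

    from alembic import command
    from alembic.config import Config

    # Load the ini file committed above and run all pending migrations.
    cfg = Config("db/migrations/alembic.ini")
    command.upgrade(cfg, "head")

    # Autogenerating a new revision against the models works the same way:
    # command.revision(cfg, message="describe the change", autogenerate=True)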
82  db/migrations/env.py  (new file)
@@ -0,0 +1,82 @@
"""
Alembic environment configuration.
"""
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

from memory.common import settings
from memory.common.db.models import Base


# this is the Alembic Config object
config = context.config

# setup database URL from environment variables
config.set_main_option("sqlalchemy.url", settings.DB_URL)

# Interpret the config file for Python logging
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")


def run_migrations_offline() -> None:
    """
    Run migrations in 'offline' mode - creates SQL scripts.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """
    Run migrations in 'online' mode - directly to the database.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
26  db/migrations/script.py.mako  (new file)
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
341  db/migrations/versions/20250427_171537_initial_structure.py  (new file)
@@ -0,0 +1,341 @@
"""Initial structure

Revision ID: a466a07360d5
Revises:
Create Date: 2025-04-27 17:15:37.487616

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "a466a07360d5"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute('CREATE EXTENSION IF NOT EXISTS pgcrypto')

    # Create enum type for github_item with IF NOT EXISTS
    op.execute("DO $$ BEGIN CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card'); EXCEPTION WHEN duplicate_object THEN NULL; END $$;")

    op.create_table(
        "email_accounts",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("email_address", sa.Text(), nullable=False),
        sa.Column("imap_server", sa.Text(), nullable=False),
        sa.Column("imap_port", sa.Integer(), server_default="993", nullable=False),
        sa.Column("username", sa.Text(), nullable=False),
        sa.Column("password", sa.Text(), nullable=False),
        sa.Column("use_ssl", sa.Boolean(), server_default="true", nullable=False),
        sa.Column("folders", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
        sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
        sa.Column("last_sync_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("active", sa.Boolean(), server_default="true", nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("email_address"),
    )
    op.create_table(
        "rss_feeds",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("url", sa.Text(), nullable=False),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
        sa.Column("last_checked_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("active", sa.Boolean(), server_default="true", nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("url"),
    )
    op.create_table(
        "source_item",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("modality", sa.Text(), nullable=False),
        sa.Column("sha256", postgresql.BYTEA(), nullable=False),
        sa.Column(
            "inserted_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        ),
        sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
        sa.Column("lang", sa.Text(), nullable=True),
        sa.Column("model_hash", sa.Text(), nullable=True),
        sa.Column(
            "vector_ids", sa.ARRAY(sa.Text()), server_default="{}", nullable=False
        ),
        sa.Column("embed_status", sa.Text(), server_default="RAW", nullable=False),
        sa.Column("byte_length", sa.Integer(), nullable=True),
        sa.Column("mime_type", sa.Text(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("sha256"),
    )
    op.create_table(
        "blog_post",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("url", sa.Text(), nullable=True),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("published", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("url"),
    )
    op.create_table(
        "book_doc",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("author", sa.Text(), nullable=True),
        sa.Column("chapter", sa.Text(), nullable=True),
        sa.Column("published", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "chat_message",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("platform", sa.Text(), nullable=True),
        sa.Column("channel_id", sa.Text(), nullable=True),
        sa.Column("author", sa.Text(), nullable=True),
        sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("body_raw", sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "git_commit",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("repo_path", sa.Text(), nullable=True),
        sa.Column("commit_sha", sa.Text(), nullable=True),
        sa.Column("author_name", sa.Text(), nullable=True),
        sa.Column("author_email", sa.Text(), nullable=True),
        sa.Column("author_date", sa.DateTime(timezone=True), nullable=True),
        sa.Column("msg_raw", sa.Text(), nullable=True),
        sa.Column("diff_summary", sa.Text(), nullable=True),
        sa.Column("files_changed", sa.ARRAY(sa.Text()), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("commit_sha"),
    )
    op.create_table(
        "mail_message",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("message_id", sa.Text(), nullable=True),
        sa.Column("subject", sa.Text(), nullable=True),
        sa.Column("sender", sa.Text(), nullable=True),
        sa.Column("recipients", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column("sent_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("body_raw", sa.Text(), nullable=True),
        sa.Column(
            "attachments", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.Column("tsv", postgresql.TSVECTOR(), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("message_id"),
    )
    op.create_table(
        "misc_doc",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("path", sa.Text(), nullable=True),
        sa.Column("mime_type", sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "photo",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("file_path", sa.Text(), nullable=True),
        sa.Column("exif_taken_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("exif_lat", sa.Numeric(9, 6), nullable=True),
        sa.Column("exif_lon", sa.Numeric(9, 6), nullable=True),
        sa.Column("camera", sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add github_item table
    op.create_table(
        "github_item",
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("source_id", sa.BigInteger(), nullable=False),
        sa.Column("kind", sa.Text(), nullable=False),
        sa.Column("repo_path", sa.Text(), nullable=False),
        sa.Column("number", sa.Integer(), nullable=True),
        sa.Column("parent_number", sa.Integer(), nullable=True),
        sa.Column("commit_sha", sa.Text(), nullable=True),
        sa.Column("state", sa.Text(), nullable=True),
        sa.Column("title", sa.Text(), nullable=True),
        sa.Column("body_raw", sa.Text(), nullable=True),
        sa.Column("labels", sa.ARRAY(sa.Text()), nullable=True),
        sa.Column("author", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("merged_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("diff_summary", sa.Text(), nullable=True),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add constraint to github_item.kind
    op.create_check_constraint(
        "github_item_kind_check",
        "github_item",
        "kind IN ('issue', 'pr', 'comment', 'project_card')"
    )

    # Add missing constraint to source_item
    op.create_check_constraint(
        "source_item_embed_status_check",
        "source_item",
        "embed_status IN ('RAW','QUEUED','STORED','FAILED')"
    )

    # Create trigger function for vector_ids validation
    op.execute('''
    CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
    RETURNS TRIGGER LANGUAGE plpgsql AS $$
    BEGIN
        IF NEW.embed_status = 'STORED'
           AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
            RAISE EXCEPTION
              USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
        END IF;
        RETURN NEW;
    END;
    $$;
    ''')

    # Create trigger
    op.execute('''
    CREATE TRIGGER check_vector_ids
    BEFORE UPDATE ON source_item
    FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();
    ''')

    # Create indexes for source_item
    op.create_index('source_modality_idx', 'source_item', ['modality'])
    op.create_index('source_status_idx', 'source_item', ['embed_status'])
    op.create_index('source_tags_idx', 'source_item', ['tags'], postgresql_using='gin')

    # Create indexes for mail_message
    op.create_index('mail_sent_idx', 'mail_message', ['sent_at'])
    op.create_index('mail_recipients_idx', 'mail_message', ['recipients'], postgresql_using='gin')
    op.create_index('mail_tsv_idx', 'mail_message', ['tsv'], postgresql_using='gin')

    # Create index for chat_message
    op.create_index('chat_channel_idx', 'chat_message', ['platform', 'channel_id'])

    # Create indexes for git_commit
    op.create_index('git_files_idx', 'git_commit', ['files_changed'], postgresql_using='gin')
    op.create_index('git_date_idx', 'git_commit', ['author_date'])

    # Create index for photo
    op.create_index('photo_taken_idx', 'photo', ['exif_taken_at'])

    # Create indexes for rss_feeds
    op.create_index('rss_feeds_active_idx', 'rss_feeds', ['active', 'last_checked_at'])
    op.create_index('rss_feeds_tags_idx', 'rss_feeds', ['tags'], postgresql_using='gin')

    # Create indexes for email_accounts
    op.create_index('email_accounts_address_idx', 'email_accounts', ['email_address'], unique=True)
    op.create_index('email_accounts_active_idx', 'email_accounts', ['active', 'last_sync_at'])
    op.create_index('email_accounts_tags_idx', 'email_accounts', ['tags'], postgresql_using='gin')

    # Create indexes for github_item
    op.create_index('gh_repo_kind_idx', 'github_item', ['repo_path', 'kind'])
    op.create_index('gh_issue_lookup_idx', 'github_item', ['repo_path', 'kind', 'number'])
    op.create_index('gh_labels_idx', 'github_item', ['labels'], postgresql_using='gin')

    # Create add_tags helper function
    op.execute('''
    CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
    RETURNS VOID LANGUAGE SQL AS $$
        UPDATE source_item
        SET tags =
            (SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
        WHERE id = p_source;
    $$;
    ''')


def downgrade() -> None:
    # Drop indexes.  gh_tsv_idx is only ever created by the old schema.sql, never by
    # this migration, so guard its drop instead of calling op.drop_index on it.
    op.execute("DROP INDEX IF EXISTS gh_tsv_idx")
    op.drop_index('gh_labels_idx', table_name='github_item')
    op.drop_index('gh_issue_lookup_idx', table_name='github_item')
    op.drop_index('gh_repo_kind_idx', table_name='github_item')
    op.drop_index('email_accounts_tags_idx', table_name='email_accounts')
    op.drop_index('email_accounts_active_idx', table_name='email_accounts')
    op.drop_index('email_accounts_address_idx', table_name='email_accounts')
    op.drop_index('rss_feeds_tags_idx', table_name='rss_feeds')
    op.drop_index('rss_feeds_active_idx', table_name='rss_feeds')
    op.drop_index('photo_taken_idx', table_name='photo')
    op.drop_index('git_date_idx', table_name='git_commit')
    op.drop_index('git_files_idx', table_name='git_commit')
    op.drop_index('chat_channel_idx', table_name='chat_message')
    op.drop_index('mail_tsv_idx', table_name='mail_message')
    op.drop_index('mail_recipients_idx', table_name='mail_message')
    op.drop_index('mail_sent_idx', table_name='mail_message')
    op.drop_index('source_tags_idx', table_name='source_item')
    op.drop_index('source_status_idx', table_name='source_item')
    op.drop_index('source_modality_idx', table_name='source_item')

    # Drop the trigger and helper functions before the tables they reference go away
    op.execute("DROP TRIGGER IF EXISTS check_vector_ids ON source_item")
    op.execute("DROP FUNCTION IF EXISTS trg_vector_ids_not_empty()")
    op.execute("DROP FUNCTION IF EXISTS add_tags(BIGINT, TEXT[])")

    # Drop tables
    op.drop_table("photo")
    op.drop_table("misc_doc")
    op.drop_table("mail_message")
    op.drop_table("git_commit")
    op.drop_table("chat_message")
    op.drop_table("book_doc")
    op.drop_table("blog_post")
    op.drop_table("github_item")
    op.drop_table("source_item")
    op.drop_table("rss_feeds")
    op.drop_table("email_accounts")

    # Drop enum type
    op.execute("DROP TYPE IF EXISTS gh_item_kind")
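The migration also installs the add_tags() SQL helper carried over from schema.sql. A hedged sketch of calling it from application code — the session object and the id/tag values are placeholders, and psycopg2 is assumed so the Python list binds to TEXT[]:

    from sqlalchemy import text

    # Merge extra tags into source_item.tags without creating duplicates.
    session.execute(
        text("SELECT add_tags(:source_id, :tags)"),
        {"source_id": 42, "tags": ["mail", "important"]},  # placeholder values
    )
    session.commit()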
229  db/schema.sql  (deleted)
@@ -1,229 +0,0 @@
/*========================================================================
  Knowledge-Base schema – first-run script
  ---------------------------------------------------------------
  • PostgreSQL 15+
  • Creates every table, index, trigger and helper in one pass
  • No ALTER statements or later migrations required
  • Enable pgcrypto for UUID helpers (safe to re-run)
========================================================================*/

-------------------------------------------------------------------------------
-- 0. EXTENSIONS
-------------------------------------------------------------------------------
CREATE EXTENSION IF NOT EXISTS pgcrypto;      -- gen_random_uuid(), crypt()

-------------------------------------------------------------------------------
-- 1. CANONICAL ARTEFACT TABLE (everything points here)
-------------------------------------------------------------------------------
CREATE TABLE source_item (
    id            BIGSERIAL PRIMARY KEY,
    modality      TEXT NOT NULL,                  -- 'mail'|'chat'|...
    sha256        BYTEA UNIQUE NOT NULL,          -- 32-byte blob
    inserted_at   TIMESTAMPTZ DEFAULT NOW(),
    tags          TEXT[] NOT NULL DEFAULT '{}',   -- flexible labels
    lang          TEXT,                           -- ISO-639-1 or NULL
    model_hash    TEXT,                           -- embedding model ver.
    vector_ids    TEXT[] NOT NULL DEFAULT '{}',   -- 0-N Qdrant IDs
    embed_status  TEXT NOT NULL DEFAULT 'RAW'
                  CHECK (embed_status IN ('RAW','QUEUED','STORED','FAILED')),
    byte_length   INTEGER,                        -- original size
    mime_type     TEXT
);

CREATE INDEX source_modality_idx ON source_item (modality);
CREATE INDEX source_status_idx   ON source_item (embed_status);
CREATE INDEX source_tags_idx     ON source_item USING GIN (tags);

-- 1.a Trigger – vector_ids must be present when status = STORED
CREATE OR REPLACE FUNCTION trg_vector_ids_not_empty()
RETURNS TRIGGER LANGUAGE plpgsql AS $$
BEGIN
    IF NEW.embed_status = 'STORED'
       AND (NEW.vector_ids IS NULL OR array_length(NEW.vector_ids,1) = 0) THEN
        RAISE EXCEPTION
          USING MESSAGE = 'vector_ids must not be empty when embed_status = STORED';
    END IF;
    RETURN NEW;
END;
$$;
CREATE TRIGGER check_vector_ids
BEFORE UPDATE ON source_item
FOR EACH ROW EXECUTE FUNCTION trg_vector_ids_not_empty();

-------------------------------------------------------------------------------
-- 2. MAIL MESSAGES
-------------------------------------------------------------------------------
CREATE TABLE mail_message (
    id           BIGSERIAL PRIMARY KEY,
    source_id    BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    message_id   TEXT UNIQUE,
    subject      TEXT,
    sender       TEXT,
    recipients   TEXT[],
    sent_at      TIMESTAMPTZ,
    body_raw     TEXT,
    attachments  JSONB
);

CREATE INDEX mail_sent_idx       ON mail_message (sent_at);
CREATE INDEX mail_recipients_idx ON mail_message USING GIN (recipients);

ALTER TABLE mail_message
    ADD COLUMN tsv tsvector
    GENERATED ALWAYS AS (
        to_tsvector('english',
                    coalesce(subject,'') || ' ' || coalesce(body_raw,'')))
    STORED;
CREATE INDEX mail_tsv_idx ON mail_message USING GIN (tsv);

-------------------------------------------------------------------------------
-- 3. CHAT (Slack / Discord)
-------------------------------------------------------------------------------
CREATE TABLE chat_message (
    id          BIGSERIAL PRIMARY KEY,
    source_id   BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    platform    TEXT CHECK (platform IN ('slack','discord')),
    channel_id  TEXT,
    author      TEXT,
    sent_at     TIMESTAMPTZ,
    body_raw    TEXT
);
CREATE INDEX chat_channel_idx ON chat_message (platform, channel_id);

-------------------------------------------------------------------------------
-- 4. GIT COMMITS (local repos)
-------------------------------------------------------------------------------
CREATE TABLE git_commit (
    id             BIGSERIAL PRIMARY KEY,
    source_id      BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    repo_path      TEXT,
    commit_sha     TEXT UNIQUE,
    author_name    TEXT,
    author_email   TEXT,
    author_date    TIMESTAMPTZ,
    msg_raw        TEXT,
    diff_summary   TEXT,
    files_changed  TEXT[]
);
CREATE INDEX git_files_idx ON git_commit USING GIN (files_changed);
CREATE INDEX git_date_idx  ON git_commit (author_date);

-------------------------------------------------------------------------------
-- 5. PHOTOS
-------------------------------------------------------------------------------
CREATE TABLE photo (
    id             BIGSERIAL PRIMARY KEY,
    source_id      BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    file_path      TEXT,
    exif_taken_at  TIMESTAMPTZ,
    exif_lat       NUMERIC(9,6),
    exif_lon       NUMERIC(9,6),
    camera_make    TEXT,
    camera_model   TEXT
);
CREATE INDEX photo_taken_idx ON photo (exif_taken_at);

-------------------------------------------------------------------------------
-- 6. BOOKS, BLOG POSTS, MISC DOCS
-------------------------------------------------------------------------------
CREATE TABLE book_doc (
    id         BIGSERIAL PRIMARY KEY,
    source_id  BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    title      TEXT,
    author     TEXT,
    chapter    TEXT,
    published  DATE
);

CREATE TABLE blog_post (
    id         BIGSERIAL PRIMARY KEY,
    source_id  BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    url        TEXT UNIQUE,
    title      TEXT,
    published  TIMESTAMPTZ
);

CREATE TABLE misc_doc (
    id         BIGSERIAL PRIMARY KEY,
    source_id  BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,
    path       TEXT,
    mime_type  TEXT
);

-------------------------------------------------------------------------------
-- 6.5 RSS FEEDS
-------------------------------------------------------------------------------
CREATE TABLE rss_feeds (
    id               BIGSERIAL PRIMARY KEY,
    url              TEXT UNIQUE NOT NULL,
    title            TEXT,
    description      TEXT,
    tags             TEXT[] NOT NULL DEFAULT '{}',
    last_checked_at  TIMESTAMPTZ,
    active           BOOLEAN NOT NULL DEFAULT TRUE,
    created_at       TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at       TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX rss_feeds_active_idx ON rss_feeds (active, last_checked_at);
CREATE INDEX rss_feeds_tags_idx   ON rss_feeds USING GIN (tags);

-------------------------------------------------------------------------------
-- 7. GITHUB ITEMS (issues, PRs, comments, project cards)
-------------------------------------------------------------------------------
CREATE TYPE gh_item_kind AS ENUM ('issue','pr','comment','project_card');

CREATE TABLE github_item (
    id             BIGSERIAL PRIMARY KEY,
    source_id      BIGINT NOT NULL REFERENCES source_item ON DELETE CASCADE,

    kind           gh_item_kind NOT NULL,
    repo_path      TEXT NOT NULL,          -- "owner/repo"
    number         INTEGER,                -- issue/PR number (NULL for commit comment)
    parent_number  INTEGER,                -- comment → its issue/PR
    commit_sha     TEXT,                   -- for commit comments
    state          TEXT,                   -- 'open'|'closed'|'merged'
    title          TEXT,
    body_raw       TEXT,
    labels         TEXT[],
    author         TEXT,
    created_at     TIMESTAMPTZ,
    closed_at      TIMESTAMPTZ,
    merged_at      TIMESTAMPTZ,
    diff_summary   TEXT,                   -- PR only

    payload        JSONB                   -- extra GitHub fields
);

CREATE INDEX gh_repo_kind_idx    ON github_item (repo_path, kind);
CREATE INDEX gh_issue_lookup_idx ON github_item (repo_path, kind, number);
CREATE INDEX gh_labels_idx       ON github_item USING GIN (labels);

CREATE INDEX gh_tsv_idx ON github_item
    WHERE kind IN ('issue','pr')
    USING GIN (to_tsvector('english',
               coalesce(title,'') || ' ' || coalesce(body_raw,'')));

-------------------------------------------------------------------------------
-- 8. HELPER FUNCTION – add tags
-------------------------------------------------------------------------------
CREATE OR REPLACE FUNCTION add_tags(p_source BIGINT, p_tags TEXT[])
RETURNS VOID LANGUAGE SQL AS $$
    UPDATE source_item
    SET tags =
        (SELECT ARRAY(SELECT DISTINCT unnest(tags || p_tags)))
    WHERE id = p_source;
$$;

-------------------------------------------------------------------------------
-- 9. (optional) PARTITION STUBS – create per-year partitions later
-------------------------------------------------------------------------------
/*
-- example:
CREATE TABLE mail_message_2026 PARTITION OF mail_message
    FOR VALUES FROM ('2026-01-01') TO ('2027-01-01');
*/

-- =========================================================================
--  Schema creation complete
-- =========================================================================
59  dev.sh  (new, executable)
@@ -0,0 +1,59 @@
#!/usr/bin/env bash
set -eo pipefail

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

echo -e "${GREEN}Starting development environment for Memory Knowledge Base...${NC}"

# Get the directory of the script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

# Create a temporary docker-compose override file to expose PostgreSQL
echo -e "${YELLOW}Creating docker-compose override to expose PostgreSQL...${NC}"
if [ ! -f docker-compose.override.yml ]; then
    cat > docker-compose.override.yml << EOL
version: "3.9"
services:
  postgres:
    ports:
      - "5432:5432"
EOL
fi

# Start the containers
echo -e "${GREEN}Starting docker containers...${NC}"
docker-compose up -d postgres rabbitmq qdrant

# Wait for PostgreSQL to be ready
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
for i in {1..30}; do
    if docker-compose exec postgres pg_isready -U kb > /dev/null 2>&1; then
        echo -e "${GREEN}PostgreSQL is ready!${NC}"
        break
    fi
    echo -n "."
    sleep 1
done

# Initialize the database if needed
echo -e "${YELLOW}Checking if database needs initialization...${NC}"
if ! docker-compose exec postgres psql -U kb -d kb -c "SELECT 1 FROM information_schema.tables WHERE table_name = 'source_item'" | grep -q 1; then
    echo -e "${GREEN}Initializing database from schema.sql...${NC}"
    docker-compose exec postgres psql -U kb -d kb -f /docker-entrypoint-initdb.d/schema.sql
else
    echo -e "${GREEN}Database already initialized.${NC}"
fi

echo -e "${GREEN}Development environment is ready!${NC}"
echo -e "${YELLOW}PostgreSQL is available at localhost:5432${NC}"
echo -e "${YELLOW}Username: kb${NC}"
echo -e "${YELLOW}Password: (check secrets/postgres_password.txt)${NC}"
echo -e "${YELLOW}Database: kb${NC}"
echo ""
echo -e "${GREEN}To stop the environment, run:${NC}"
echo -e "${YELLOW}docker-compose down${NC}"
@@ -61,7 +61,6 @@ services:
    secrets: [postgres_password]
    volumes:
      - db_data:/var/lib/postgresql/data:rw
      - ./db:/docker-entrypoint-initdb.d:ro
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U kb"]
      interval: 10s
@@ -1,3 +1,5 @@
sqlalchemy==2.0.30
psycopg2-binary==2.9.9
pydantic==2.7.1
alembic==1.13.1
dotenv==1.1.0
5  requirements-dev.txt  (new file)
@@ -0,0 +1,5 @@
pytest==7.4.4
pytest-cov==4.1.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
3  setup.py
@@ -16,6 +16,7 @@ def read_requirements(filename: str) -> list[str]:
common_requires = read_requirements('requirements-common.txt')
api_requires = read_requirements('requirements-api.txt')
workers_requires = read_requirements('requirements-workers.txt')
dev_requires = read_requirements('requirements-dev.txt')

setup(
    name="memory",

@@ -27,5 +28,7 @@ setup(
        "api": api_requires + common_requires,
        "workers": workers_requires + common_requires,
        "common": common_requires,
        "dev": dev_requires,
        "all": api_requires + workers_requires + common_requires + dev_requires,
    },
)
@@ -29,4 +29,9 @@ def get_scoped_session():
    """Create a thread-local scoped session factory"""
    engine = get_engine()
    session_factory = sessionmaker(bind=engine)
    return scoped_session(session_factory)


def make_session():
    with get_scoped_session() as session:
        yield session
@@ -2,8 +2,8 @@
"""
Database models for the knowledge base system.
"""
from sqlalchemy import (
    Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean, Float,
    ARRAY, func
    Column, ForeignKey, Integer, BigInteger, Text, DateTime, Boolean,
    ARRAY, func, Numeric, CheckConstraint, Index
)
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, TSVECTOR
from sqlalchemy.ext.declarative import declarative_base

@@ -26,6 +26,14 @@ class SourceItem(Base):
    embed_status = Column(Text, nullable=False, server_default='RAW')
    byte_length = Column(Integer)
    mime_type = Column(Text)

    # Add table-level constraint and indexes
    __table_args__ = (
        CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
        Index('source_modality_idx', 'modality'),
        Index('source_status_idx', 'embed_status'),
        Index('source_tags_idx', 'tags', postgresql_using='gin'),
    )


class MailMessage(Base):

@@ -41,6 +49,13 @@ class MailMessage(Base):
    body_raw = Column(Text)
    attachments = Column(JSONB)
    tsv = Column(TSVECTOR)

    # Add indexes
    __table_args__ = (
        Index('mail_sent_idx', 'sent_at'),
        Index('mail_recipients_idx', 'recipients', postgresql_using='gin'),
        Index('mail_tsv_idx', 'tsv', postgresql_using='gin'),
    )


class ChatMessage(Base):

@@ -53,6 +68,11 @@ class ChatMessage(Base):
    author = Column(Text)
    sent_at = Column(DateTime(timezone=True))
    body_raw = Column(Text)

    # Add index
    __table_args__ = (
        Index('chat_channel_idx', 'platform', 'channel_id'),
    )


class GitCommit(Base):

@@ -68,6 +88,12 @@ class GitCommit(Base):
    msg_raw = Column(Text)
    diff_summary = Column(Text)
    files_changed = Column(ARRAY(Text))

    # Add indexes
    __table_args__ = (
        Index('git_files_idx', 'files_changed', postgresql_using='gin'),
        Index('git_date_idx', 'author_date'),
    )


class Photo(Base):

@@ -77,10 +103,14 @@ class Photo(Base):
    source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
    file_path = Column(Text)
    exif_taken_at = Column(DateTime(timezone=True))
    exif_lat = Column(Float)
    exif_lon = Column(Float)
    camera_make = Column(Text)
    camera_model = Column(Text)
    exif_lat = Column(Numeric(9, 6))
    exif_lon = Column(Numeric(9, 6))
    camera = Column(Text)

    # Add index
    __table_args__ = (
        Index('photo_taken_idx', 'exif_taken_at'),
    )


class BookDoc(Base):

@@ -91,7 +121,7 @@ class BookDoc(Base):
    title = Column(Text)
    author = Column(Text)
    chapter = Column(Text)
    published = Column(DateTime)
    published = Column(DateTime(timezone=True))


class BlogPost(Base):

@@ -124,4 +154,67 @@ class RssFeed(Base):
    last_checked_at = Column(DateTime(timezone=True))
    active = Column(Boolean, nullable=False, server_default='true')
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    # Add indexes
    __table_args__ = (
        Index('rss_feeds_active_idx', 'active', 'last_checked_at'),
        Index('rss_feeds_tags_idx', 'tags', postgresql_using='gin'),
    )


class EmailAccount(Base):
    __tablename__ = 'email_accounts'

    id = Column(BigInteger, primary_key=True)
    name = Column(Text, nullable=False)
    email_address = Column(Text, nullable=False, unique=True)
    imap_server = Column(Text, nullable=False)
    imap_port = Column(Integer, nullable=False, server_default='993')
    username = Column(Text, nullable=False)
    password = Column(Text, nullable=False)
    use_ssl = Column(Boolean, nullable=False, server_default='true')
    folders = Column(ARRAY(Text), nullable=False, server_default='{}')
    tags = Column(ARRAY(Text), nullable=False, server_default='{}')
    last_sync_at = Column(DateTime(timezone=True))
    active = Column(Boolean, nullable=False, server_default='true')
    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())
    updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now())

    # Add indexes
    __table_args__ = (
        Index('email_accounts_address_idx', 'email_address', unique=True),
        Index('email_accounts_active_idx', 'active', 'last_sync_at'),
        Index('email_accounts_tags_idx', 'tags', postgresql_using='gin'),
    )


class GithubItem(Base):
    __tablename__ = 'github_item'

    id = Column(BigInteger, primary_key=True)
    source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)

    kind = Column(Text, nullable=False)
    repo_path = Column(Text, nullable=False)
    number = Column(Integer)
    parent_number = Column(Integer)
    commit_sha = Column(Text)
    state = Column(Text)
    title = Column(Text)
    body_raw = Column(Text)
    labels = Column(ARRAY(Text))
    author = Column(Text)
    created_at = Column(DateTime(timezone=True))
    closed_at = Column(DateTime(timezone=True))
    merged_at = Column(DateTime(timezone=True))
    diff_summary = Column(Text)

    payload = Column(JSONB)

    __table_args__ = (
        CheckConstraint("kind IN ('issue', 'pr', 'comment', 'project_card')"),
        Index('gh_repo_kind_idx', 'repo_path', 'kind'),
        Index('gh_issue_lookup_idx', 'repo_path', 'kind', 'number'),
        Index('gh_labels_idx', 'labels', postgresql_using='gin'),
    )
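For reference, a hedged sketch of seeding the new EmailAccount model through the scoped session from memory.common.db.connection — all the account values below are placeholders, not repo defaults:

    from memory.common.db.connection import get_scoped_session
    from memory.common.db.models import EmailAccount

    session = get_scoped_session()
    session.add(EmailAccount(
        name="personal",                  # placeholder
        email_address="me@example.com",   # placeholder
        imap_server="imap.example.com",   # placeholder
        username="me@example.com",
        password="app-password",
        folders=["INBOX"],
        tags=["personal"],
    ))
    session.commit()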
16  src/memory/common/settings.py  (new file)
@@ -0,0 +1,16 @@
import os
from dotenv import load_dotenv

load_dotenv()


DB_USER = os.getenv("DB_USER", "kb")
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
DB_HOST = os.getenv("DB_HOST", "postgres")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "kb")

def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT, db=DB_NAME):
    return f"postgresql://{user}:{password}@{host}:{port}/{db}"

DB_URL = os.getenv("DATABASE_URL", make_db_url())
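A quick sketch of how these settings resolve; the localhost override is illustrative, matching the port that dev.sh exposes:

    from memory.common import settings

    # With no environment variables set, the defaults compose:
    #   postgresql://kb:kb@postgres:5432/kb
    print(settings.DB_URL)

    # Pointing at the port exposed by dev.sh instead (illustrative override):
    print(settings.make_db_url(host="localhost"))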
362  src/memory/workers/email.py  (new file)
@@ -0,0 +1,362 @@
import email
import hashlib
import imaplib
import logging
import re
from contextlib import contextmanager
from datetime import datetime
from email.utils import parsedate_to_datetime

from sqlalchemy.orm import Session

from memory.common.db.models import EmailAccount, MailMessage, SourceItem


logger = logging.getLogger(__name__)


def extract_recipients(msg: email.message.Message) -> list[str]:
    """
    Extract email recipients from message headers.

    Args:
        msg: Email message object

    Returns:
        List of recipient email addresses
    """
    return [
        recipient
        for field in ["To", "Cc", "Bcc"]
        if (field_value := msg.get(field, ""))
        for r in field_value.split(",")
        if (recipient := r.strip())
    ]


def extract_date(msg: email.message.Message) -> datetime | None:
    """
    Parse date from email header.

    Args:
        msg: Email message object

    Returns:
        Parsed datetime or None if parsing failed
    """
    if date_str := msg.get("Date"):
        try:
            return parsedate_to_datetime(date_str)
        except Exception:
            logger.warning(f"Could not parse date: {date_str}")
    return None


def extract_body(msg: email.message.Message) -> str:
    """
    Extract plain text body from email message.

    Args:
        msg: Email message object

    Returns:
        Plain text body content
    """
    body = ""

    if not msg.is_multipart():
        try:
            return msg.get_payload(decode=True).decode(errors='replace')
        except Exception as e:
            logger.error(f"Error decoding message body: {str(e)}")
            return ""

    for part in msg.walk():
        content_type = part.get_content_type()
        content_disposition = str(part.get("Content-Disposition", ""))

        if content_type == "text/plain" and "attachment" not in content_disposition:
            try:
                body += part.get_payload(decode=True).decode(errors='replace') + "\n"
            except Exception as e:
                logger.error(f"Error decoding message part: {str(e)}")
    return body


def extract_attachments(msg: email.message.Message) -> list[dict]:
    """
    Extract attachment metadata from email.

    Args:
        msg: Email message object

    Returns:
        List of attachment metadata dicts
    """
    if not msg.is_multipart():
        return []

    attachments = []
    for part in msg.walk():
        content_disposition = part.get("Content-Disposition", "")
        if "attachment" not in content_disposition:
            continue

        if filename := part.get_filename():
            attachments.append({
                "filename": filename,
                "content_type": part.get_content_type(),
                "size": len(part.get_payload(decode=True))
            })

    return attachments


def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
    """
    Compute a SHA-256 hash of message content.

    Args:
        msg_id: Message ID
        subject: Email subject
        sender: Sender email
        body: Message body

    Returns:
        SHA-256 hash as bytes
    """
    hash_content = (msg_id + subject + sender + body).encode()
    return hashlib.sha256(hash_content).digest()


def parse_email_message(raw_email: str) -> dict:
    """
    Parse raw email into structured data.

    Args:
        raw_email: Raw email content as string

    Returns:
        Dict with parsed email data
    """
    msg = email.message_from_string(raw_email)

    return {
        "message_id": msg.get("Message-ID", ""),
        "subject": msg.get("Subject", ""),
        "sender": msg.get("From", ""),
        "recipients": extract_recipients(msg),
        "sent_at": extract_date(msg),
        "body": extract_body(msg),
        "attachments": extract_attachments(msg)
    }


def create_source_item(
    db_session: Session,
    message_hash: bytes,
    account_tags: list[str],
    raw_email_size: int,
) -> SourceItem:
    """
    Create a new source item record.

    Args:
        db_session: Database session
        message_hash: SHA-256 hash of message
        account_tags: Tags from the email account
        raw_email_size: Size of raw email in bytes

    Returns:
        Newly created SourceItem
    """
    source_item = SourceItem(
        modality="mail",
        sha256=message_hash,
        tags=account_tags,
        byte_length=raw_email_size,
        mime_type="message/rfc822",
        embed_status="RAW"
    )
    db_session.add(source_item)
    db_session.flush()
    return source_item


def create_mail_message(
    db_session: Session,
    source_id: int,
    parsed_email: dict,
    folder: str,
) -> MailMessage:
    """
    Create a new mail message record.

    Args:
        db_session: Database session
        source_id: ID of the SourceItem
        parsed_email: Parsed email data
        folder: IMAP folder name

    Returns:
        Newly created MailMessage
    """
    mail_message = MailMessage(
        source_id=source_id,
        message_id=parsed_email["message_id"],
        subject=parsed_email["subject"],
        sender=parsed_email["sender"],
        recipients=parsed_email["recipients"],
        sent_at=parsed_email["sent_at"],
        body_raw=parsed_email["body"],
        attachments={"items": parsed_email["attachments"], "folder": folder}
    )
    db_session.add(mail_message)
    return mail_message


def check_message_exists(db_session: Session, message_id: str, message_hash: bytes) -> bool:
    """
    Check if a message already exists in the database.

    Args:
        db_session: Database session
        message_id: Email message ID
        message_hash: SHA-256 hash of message

    Returns:
        True if message exists, False otherwise
    """
    return (
        # Check by message_id first (faster)
        message_id and db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first()
        # Then check by message_hash
        or db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() is not None
    )


def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
    """
    Extract the UID and raw email data from the message data.
    """
    uid_pattern = re.compile(r'UID (\d+)')
    uid_match = uid_pattern.search(msg_data[0][0].decode('utf-8', errors='replace'))
    uid = uid_match.group(1) if uid_match else None
    raw_email = msg_data[0][1]
    return uid, raw_email


def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> tuple[str, bytes] | None:
    try:
        status, msg_data = conn.fetch(uid, '(UID RFC822)')
        if status != 'OK' or not msg_data or not msg_data[0]:
            logger.error(f"Error fetching message {uid}")
            return None

        return extract_email_uid(msg_data)
    except Exception as e:
        logger.error(f"Error processing message {uid}: {str(e)}")
        return None


def fetch_email_since(
    conn: imaplib.IMAP4_SSL,
    folder: str,
    since_date: datetime
) -> list[tuple[str, bytes]]:
    """
    Fetch emails from a folder since a given date.

    Args:
        conn: IMAP connection
        folder: Folder name to select
        since_date: Fetch emails since this date

    Returns:
        List of tuples with (uid, raw_email)
    """
    try:
        status, counts = conn.select(folder)
        if status != 'OK':
            logger.error(f"Error selecting folder {folder}: {counts}")
            return []

        date_str = since_date.strftime("%d-%b-%Y")

        status, data = conn.search(None, f'(SINCE "{date_str}")')
        if status != 'OK':
            logger.error(f"Error searching folder {folder}: {data}")
            return []
    except Exception as e:
        logger.error(f"Error in fetch_email_since for folder {folder}: {str(e)}")
        return []

    if not data or not data[0]:
        return []

    return [email for uid in data[0].split() if (email := fetch_email(conn, uid))]
def process_folder(
    conn: imaplib.IMAP4_SSL,
    folder: str,
    account: EmailAccount,
    since_date: datetime
) -> dict:
    """
    Process a single folder from an email account.

    Args:
        conn: Active IMAP connection
        folder: Folder name to process
        account: Email account configuration
        since_date: Only fetch messages newer than this date

    Returns:
        Stats dictionary for the folder
    """
    # Imported here to avoid a circular import with memory.workers.tasks.email,
    # which itself imports the helpers defined in this module.
    from memory.workers.tasks.email import process_message

    emails = []  # keep defined even if fetching fails, so the stats below still work
    new_messages, errors = 0, 0

    try:
        emails = fetch_email_since(conn, folder, since_date)

        for uid, raw_email in emails:
            try:
                task = process_message.delay(
                    account_id=account.id,
                    message_id=uid,
                    folder=folder,
                    raw_email=raw_email.decode('utf-8', errors='replace')
                )
                if task:
                    new_messages += 1
            except Exception as e:
                logger.error(f"Error queuing message {uid}: {str(e)}")
                errors += 1

    except Exception as e:
        logger.error(f"Error processing folder {folder}: {str(e)}")
        errors += 1

    return {
        "messages_found": len(emails),
        "new_messages": new_messages,
        "errors": errors
    }
@contextmanager
def imap_connection(account: EmailAccount) -> imaplib.IMAP4_SSL:
    conn = imaplib.IMAP4_SSL(
        host=account.imap_server,
        port=account.imap_port
    )
    try:
        conn.login(account.username, account.password)
        yield conn
    finally:
        # Always try to logout and close the connection
        try:
            conn.logout()
        except Exception as e:
            logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
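Taken together, the helpers above turn a raw RFC-822 string into the dict that process_message consumes, plus the sha256 used for de-duplication. A minimal offline sketch; the sample message is made up:

    from memory.workers.email import compute_message_hash, parse_email_message

    raw = (
        "Message-ID: <123@example.com>\r\n"
        "From: alice@example.com\r\n"
        "To: bob@example.com\r\n"
        "Subject: hello\r\n"
        "Date: Mon, 28 Apr 2025 10:00:00 +0000\r\n"
        "\r\n"
        "Just checking in.\r\n"
    )
    parsed = parse_email_message(raw)
    digest = compute_message_hash(
        parsed["message_id"], parsed["subject"], parsed["sender"], parsed["body"]
    )
    assert len(digest) == 32  # this is what ends up in source_item.sha256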
@@ -1,4 +1,4 @@
"""
Import sub-modules so Celery can register their @app.task decorators.
"""
from memory.workers.tasks import text, photo, ocr, git, rss, docs  # noqa
from memory.workers.tasks import text, photo, ocr, git, rss, docs, email  # noqa
137
src/memory/workers/tasks/email.py
Normal file
137
src/memory/workers/tasks/email.py
Normal file
@ -0,0 +1,137 @@
import logging
from datetime import datetime

from memory.common.db.connection import make_session
from memory.common.db.models import EmailAccount
from memory.workers.celery_app import app
from memory.workers.email import (
    check_message_exists,
    compute_message_hash,
    create_mail_message,
    create_source_item,
    imap_connection,
    parse_email_message,
    process_folder,
)


logger = logging.getLogger(__name__)


@app.task(name="memory.email.process_message")
def process_message(
    account_id: int, message_id: str, folder: str, raw_email: str,
) -> int | None:
    """
    Process a single email message and store it in the database.

    Args:
        account_id: ID of the EmailAccount
        message_id: UID of the message on the server
        folder: Folder name where the message is stored
        raw_email: Raw email content as string

    Returns:
        source_id if successful, None otherwise
    """
    with make_session() as db:
        account = db.query(EmailAccount).get(account_id)
        if not account:
            logger.error(f"Account {account_id} not found")
            return None

        parsed_email = parse_email_message(raw_email)

        # Use server-provided message ID if missing
        if not parsed_email["message_id"]:
            parsed_email["message_id"] = f"generated-{message_id}"

        message_hash = compute_message_hash(
            parsed_email["message_id"],
            parsed_email["subject"],
            parsed_email["sender"],
            parsed_email["body"]
        )

        if check_message_exists(db, parsed_email["message_id"], message_hash):
            logger.debug(f"Message {parsed_email['message_id']} already exists in database")
            return None

        source_item = create_source_item(db, message_hash, account.tags, len(raw_email))

        create_mail_message(db, source_item.id, parsed_email, folder)

        db.commit()

        # TODO: Queue for embedding once that's implemented

        return source_item.id


@app.task(name="memory.email.sync_account")
def sync_account(account_id: int) -> dict:
    """
    Synchronize emails from a specific account.

    Args:
        account_id: ID of the EmailAccount to sync

    Returns:
        dict with stats about the sync operation
    """
    with make_session() as db:
        account = db.query(EmailAccount).filter(EmailAccount.id == account_id).first()
        if not account or not account.active:
            logger.warning(f"Account {account_id} not found or inactive")
            return {"error": "Account not found or inactive"}

        folders_to_process = account.folders or ["INBOX"]
        since_date = account.last_sync_at or datetime(1970, 1, 1)

        messages_found = 0
        new_messages = 0
        errors = 0

        try:
            with imap_connection(account) as conn:
                for folder in folders_to_process:
                    folder_stats = process_folder(conn, folder, account, since_date)

                    messages_found += folder_stats["messages_found"]
                    new_messages += folder_stats["new_messages"]
                    errors += folder_stats["errors"]

            account.last_sync_at = datetime.now()
            db.commit()
        except Exception as e:
            logger.error(f"Error connecting to server {account.imap_server}: {str(e)}")
            return {"error": str(e)}

        return {
            "account": account.email_address,
            "folders_processed": len(folders_to_process),
            "messages_found": messages_found,
            "new_messages": new_messages,
            "errors": errors
        }


@app.task(name="memory.email.sync_all_accounts")
def sync_all_accounts() -> list[dict]:
    """
    Synchronize all active email accounts.

    Returns:
        List of task IDs that were scheduled
    """
    with make_session() as db:
        active_accounts = db.query(EmailAccount).filter(EmailAccount.active).all()

        return [
            {
                "account_id": account.id,
                "email": account.email_address,
                "task_id": sync_account.delay(account.id).id
            }
            for account in active_accounts
        ]
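
This excerpt does not show how sync_all_accounts gets triggered periodically; a typical approach is a Celery beat entry on the same app object. A hedged sketch, where the schedule name and 15-minute interval are assumptions rather than part of this commit:

# Illustrative sketch only, not part of the commit.
from celery.schedules import crontab

from memory.workers.celery_app import app

app.conf.beat_schedule = {
    "sync-all-email-accounts": {
        "task": "memory.email.sync_all_accounts",  # matches the @app.task name above
        "schedule": crontab(minute="*/15"),        # every 15 minutes (illustrative)
    },
}
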
tests/memory/workers/tasks/conftest.py (new file, 129 lines, Normal file)
@@ -0,0 +1,129 @@
import os
import subprocess
import uuid
from pathlib import Path

import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from memory.common import settings


def get_test_db_name() -> str:
    return f"test_db_{uuid.uuid4().hex[:8]}"


def create_test_database(test_db_name: str) -> str:
    """
    Create a test database with a unique name.

    Args:
        test_db_name: Name for the test database

    Returns:
        URL to the test database
    """
    admin_engine = create_engine(settings.DB_URL)

    # Create a new database
    with admin_engine.connect() as conn:
        conn.execute(text("COMMIT"))  # Close any open transaction
        conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}"))
        conn.execute(text(f"CREATE DATABASE {test_db_name}"))

    admin_engine.dispose()

    return settings.make_db_url(db=test_db_name)


def drop_test_database(test_db_name: str) -> None:
    """
    Drop the test database.

    Args:
        test_db_name: Name of the test database to drop
    """
    admin_engine = create_engine(settings.DB_URL)

    with admin_engine.connect() as conn:
        conn.execute(text("COMMIT"))  # Close any open transaction
        conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}"))


def run_alembic_migrations(db_name: str) -> None:
    """Run all Alembic migrations on the test database."""
    project_root = Path(__file__).parent.parent.parent.parent.parent
    alembic_ini = project_root / "db" / "migrations" / "alembic.ini"

    subprocess.run(
        ["alembic", "-c", str(alembic_ini), "upgrade", "head"],
        env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},
        check=True,
        capture_output=True,
    )
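
run_alembic_migrations points Alembic at the test database through a DATABASE_URL environment variable rather than the placeholder sqlalchemy.url in alembic.ini; that only works if db/migrations/env.py prefers the variable when it is set. env.py is not part of this excerpt, so the following is a hedged sketch of what that override plausibly looks like, and the exact names may differ:

# Illustrative sketch of the env.py override this fixture relies on; not part of the commit.
import os

from alembic import context

config = context.config
database_url = os.environ.get("DATABASE_URL")
if database_url:
    config.set_main_option("sqlalchemy.url", database_url)  # overrides the placeholder from alembic.ini
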

@pytest.fixture
def test_db():
    """
    Create a test database, run migrations, and clean up afterwards.

    Returns:
        The URL to the test database
    """
    test_db_name = get_test_db_name()

    # Create test database
    test_db_url = create_test_database(test_db_name)

    try:
        run_alembic_migrations(test_db_name)

        # Return the URL to the test database
        yield test_db_url
    finally:
        # Clean up - drop the test database
        drop_test_database(test_db_name)


@pytest.fixture
def db_engine(test_db):
    """
    Create a SQLAlchemy engine connected to the test database.

    Args:
        test_db: URL to the test database (from the test_db fixture)

    Returns:
        SQLAlchemy engine
    """
    engine = create_engine(test_db)
    yield engine
    engine.dispose()


@pytest.fixture
def db_session(db_engine):
    """
    Create a new database session for a test.

    Args:
        db_engine: SQLAlchemy engine (from the db_engine fixture)

    Returns:
        SQLAlchemy session
    """
    # Create a new sessionmaker
    SessionLocal = sessionmaker(bind=db_engine, autocommit=False, autoflush=False)

    # Create a new session
    session = SessionLocal()

    try:
        yield session
    finally:
        # Close and rollback the session after the test is done
        session.rollback()
        session.close()
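
With these fixtures in place, a test only has to request db_session. A minimal sketch of the intended usage; the field values are illustrative, while the column names mirror the assertions in test_create_source_item below:

# Illustrative sketch only, not part of the commit.
from memory.common.db.models import SourceItem

def test_source_item_roundtrip(db_session):
    item = SourceItem(
        modality="mail",
        sha256=b"\x00" * 32,
        tags=["test"],
        byte_length=123,
        mime_type="message/rfc822",
        embed_status="RAW",
    )
    db_session.add(item)
    db_session.flush()  # assigns the primary key without committing
    assert db_session.query(SourceItem).filter_by(id=item.id).one() is item
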
tests/memory/workers/tasks/test_email.py (new file, 325 lines, Normal file)
@@ -0,0 +1,325 @@
import email
import email.mime.multipart
import email.mime.text
import email.mime.base
from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY
import pytest

from memory.common.db.models import SourceItem
from memory.workers.email import (
    compute_message_hash,
    create_source_item,
    extract_attachments,
    extract_body,
    extract_date,
    extract_email_uid,
    extract_recipients,
    parse_email_message,
)


# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0


def _generate_test_message_id():
    """Generate a simple message ID for testing without expensive calls"""
    global _msg_id_counter
    _msg_id_counter += 1
    return f"<test-message-{_msg_id_counter}@example.com>"


def create_email_message(
    subject="Test Subject",
    from_addr="sender@example.com",
    to_addrs="recipient@example.com",
    cc_addrs=None,
    bcc_addrs=None,
    date=None,
    body="Test body content",
    attachments=None,
    multipart=True,
    message_id=None,
):
    """Helper function to create email.message.Message objects for testing"""
    if multipart:
        msg = email.mime.multipart.MIMEMultipart()
        msg.attach(email.mime.text.MIMEText(body))

        if attachments:
            for attachment in attachments:
                attachment_part = email.mime.base.MIMEBase("application", "octet-stream")
                attachment_part.set_payload(attachment["content"])
                attachment_part.add_header(
                    "Content-Disposition",
                    f"attachment; filename={attachment['filename']}"
                )
                msg.attach(attachment_part)
    else:
        msg = email.mime.text.MIMEText(body)

    msg["Subject"] = subject
    msg["From"] = from_addr
    msg["To"] = to_addrs

    if cc_addrs:
        msg["Cc"] = cc_addrs
    if bcc_addrs:
        msg["Bcc"] = bcc_addrs
    if date:
        msg["Date"] = formatdate(float(date.timestamp()))
    if message_id:
        msg["Message-ID"] = message_id
    else:
        msg["Message-ID"] = _generate_test_message_id()

    return msg

@pytest.mark.parametrize(
    "to_addr, cc_addr, bcc_addr, expected",
    [
        # Single recipient in To field
        (
            "recipient@example.com",
            None,
            None,
            ["recipient@example.com"]
        ),
        # Multiple recipients in To field
        (
            "recipient1@example.com, recipient2@example.com",
            None,
            None,
            ["recipient1@example.com", "recipient2@example.com"]
        ),
        # To, Cc fields
        (
            "recipient@example.com",
            "cc@example.com",
            None,
            ["recipient@example.com", "cc@example.com"]
        ),
        # To, Cc, Bcc fields
        (
            "recipient@example.com",
            "cc@example.com",
            "bcc@example.com",
            ["recipient@example.com", "cc@example.com", "bcc@example.com"]
        ),
        # Empty fields
        (
            "",
            "",
            "",
            []
        ),
    ]
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
    msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
    assert sorted(extract_recipients(msg)) == sorted(expected)


def test_extract_date_missing():
    msg = create_email_message(date=None)
    assert extract_date(msg) is None


@pytest.mark.parametrize(
    "date_str",
    [
        "Invalid Date Format",
        "2023-01-01",  # ISO format but not RFC compliant
        "Monday, Jan 1, 2023",  # Descriptive but not RFC compliant
        "01/01/2023",  # Common format but not RFC compliant
        "",  # Empty string
    ]
)
def test_extract_date_invalid_formats(date_str):
    msg = create_email_message()
    msg["Date"] = date_str
    assert extract_date(msg) is None


@pytest.mark.parametrize(
    "date_str",
    [
        "Mon, 01 Jan 2023 12:00:00 +0000",  # RFC 5322 format
        "01 Jan 2023 12:00:00 +0000",  # RFC 822 format
        "Mon, 01 Jan 2023 12:00:00 GMT",  # With timezone name
    ]
)
def test_extract_date(date_str):
    msg = create_email_message()
    msg["Date"] = date_str
    result = extract_date(msg)

    assert result is not None
    assert result.year == 2023
    assert result.month == 1
    assert result.day == 1


@pytest.mark.parametrize('multipart', [True, False])
def test_extract_body_text_plain(multipart):
    body_content = "This is a test email body"
    msg = create_email_message(body=body_content, multipart=multipart)
    extracted = extract_body(msg)

    # Strip newlines for comparison since multipart emails often add them
    assert extracted.strip() == body_content.strip()


def test_extract_body_with_attachments():
    body_content = "This is a test email body"
    attachments = [
        {"filename": "test.txt", "content": b"attachment content"}
    ]
    msg = create_email_message(body=body_content, attachments=attachments)
    assert body_content in extract_body(msg)


def test_extract_attachments_none():
    msg = create_email_message(multipart=True)
    assert extract_attachments(msg) == []


def test_extract_attachments_with_files():
    attachments = [
        {"filename": "test1.txt", "content": b"content1"},
        {"filename": "test2.pdf", "content": b"content2"}
    ]
    msg = create_email_message(attachments=attachments)

    result = extract_attachments(msg)
    assert len(result) == 2
    assert result[0]["filename"] == "test1.txt"
    assert result[1]["filename"] == "test2.pdf"


def test_extract_attachments_non_multipart():
    msg = create_email_message(multipart=False)
    assert extract_attachments(msg) == []


@pytest.mark.parametrize(
    "msg_id, subject, sender, body, expected",
    [
        (
            "<test@example.com>",
            "Test Subject",
            "sender@example.com",
            "Test body",
            b"\xf2\xbd"  # First two bytes of the actual hash
        ),
        (
            "<different@example.com>",
            "Test Subject",
            "sender@example.com",
            "Test body",
            b"\xa4\x15"  # Will be different from the first hash
        ),
    ]
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
    result = compute_message_hash(msg_id, subject, sender, body)

    # Verify it's bytes and correct length for SHA-256 (32 bytes)
    assert isinstance(result, bytes)
    assert len(result) == 32

    # Verify first two bytes match expected
    assert result[:2] == expected


def test_hash_consistency():
    args = ("<test@example.com>", "Test Subject", "sender@example.com", "Test body")
    assert compute_message_hash(*args) == compute_message_hash(*args)


def test_parse_simple_email():
    test_date = datetime(2023, 1, 1, 12, 0, 0)
    msg_id = "<test123@example.com>"
    msg = create_email_message(
        subject="Test Subject",
        from_addr="sender@example.com",
        to_addrs="recipient@example.com",
        date=test_date,
        body="Test body content",
        message_id=msg_id
    )

    result = parse_email_message(msg.as_string())

    assert result == {
        "message_id": msg_id,
        "subject": "Test Subject",
        "sender": "sender@example.com",
        "recipients": ["recipient@example.com"],
        "body": "Test body content\n",
        "attachments": [],
        "sent_at": ANY,
    }
    assert abs(result["sent_at"].timestamp() - test_date.timestamp()) < 86400


def test_parse_email_with_attachments():
    attachments = [
        {"filename": "test.txt", "content": b"attachment content"}
    ]
    msg = create_email_message(attachments=attachments)

    result = parse_email_message(msg.as_string())

    assert len(result["attachments"]) == 1
    assert result["attachments"][0]["filename"] == "test.txt"


def test_extract_email_uid_valid():
    msg_data = [(b'1 (UID 12345 RFC822 {1234}', b'raw email content')]
    uid, raw_email = extract_email_uid(msg_data)

    assert uid == "12345"
    assert raw_email == b'raw email content'


def test_extract_email_uid_no_match():
    msg_data = [(b'1 (RFC822 {1234}', b'raw email content')]
    uid, raw_email = extract_email_uid(msg_data)

    assert uid is None
    assert raw_email == b'raw email content'

def test_create_source_item(db_session):
    # Mock data
    message_hash = b'test_hash_bytes' + bytes(17)  # padded to 32 bytes, the length of a SHA-256 digest
    account_tags = ["work", "important"]
    raw_email_size = 1024

    # Call function
    source_item = create_source_item(
        db_session=db_session,
        message_hash=message_hash,
        account_tags=account_tags,
        raw_email_size=raw_email_size
    )

    # Verify the source item was created correctly
    assert isinstance(source_item, SourceItem)
    assert source_item.id is not None
    assert source_item.modality == "mail"
    assert source_item.sha256 == message_hash
    assert source_item.tags == account_tags
    assert source_item.byte_length == raw_email_size
    assert source_item.mime_type == "message/rfc822"
    assert source_item.embed_status == "RAW"

    # Verify it was added to the session
    db_session.flush()
    fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
    assert fetched_item is not None
    assert fetched_item.sha256 == message_hash
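
The hash tests above pin only the digest length and the first two bytes of specific inputs. For orientation, a hedged sketch of the shape compute_message_hash plausibly has; the real implementation lives in src/memory/workers/email.py and may join the fields differently, so its exact digests would differ:

# Illustrative sketch only, not part of the commit.
import hashlib

def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
    payload = f"{msg_id}\x00{subject}\x00{sender}\x00{body}".encode()
    return hashlib.sha256(payload).digest()  # raw 32-byte digest, as the tests assert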