integration tests for process_message

This commit is contained in:
Daniel O'Connell 2025-04-27 19:48:39 +02:00
parent 03b5c908ee
commit 128f8e3d64
7 changed files with 279 additions and 149 deletions

View File

@ -16,10 +16,12 @@ volumes:
db_data: {} # Postgres db_data: {} # Postgres
qdrant_data: {} # Qdrant qdrant_data: {} # Qdrant
rabbitmq_data: {} # RabbitMQ rabbitmq_data: {} # RabbitMQ
file_storage: {} # File storage
# ------------------------------ X-templates ---------------------------- # ------------------------------ X-templates ----------------------------
x-common-env: &env x-common-env: &env
RABBITMQ_USER: kb RABBITMQ_USER: kb
FILE_STORAGE_DIR: /app/memory_files
TZ: "Etc/UTC" TZ: "Etc/UTC"
@ -42,6 +44,8 @@ x-worker-base: &worker-base
read_only: true read_only: true
tmpfs: [/tmp,/var/tmp] tmpfs: [/tmp,/var/tmp]
cap_drop: [ALL] cap_drop: [ALL]
volumes:
- file_storage:/app/memory_files:rw
logging: logging:
options: {max-size: "10m", max-file: "3"} options: {max-size: "10m", max-file: "3"}

View File

@ -1,4 +1,5 @@
import os import os
import pathlib
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -14,3 +15,10 @@ def make_db_url(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT,
return f"postgresql://{user}:{password}@{host}:{port}/{db}" return f"postgresql://{user}:{password}@{host}:{port}/{db}"
DB_URL = os.getenv("DATABASE_URL", make_db_url()) DB_URL = os.getenv("DATABASE_URL", make_db_url())
# Root directory for attachment files spilled to disk; created eagerly at import time.
FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files"))
FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
# Maximum attachment size to store directly in the database.
# Default is 1 * 1024 * 1024 bytes (1 MiB) — the previous comment claimed 10MB,
# which did not match the value. Overridable via the MAX_INLINE_ATTACHMENT_SIZE env var.
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))

View File

@ -3,18 +3,42 @@ import hashlib
import imaplib import imaplib
import logging import logging
import re import re
import uuid
import base64
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from typing import Generator, Callable from typing import Generator, Callable, TypedDict, Literal
import pathlib
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem from memory.common.db.models import EmailAccount, MailMessage, SourceItem
from memory.common.settings import FILE_STORAGE_DIR, MAX_INLINE_ATTACHMENT_SIZE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Attachment(TypedDict):
    """Metadata and content for a single email attachment.

    As extracted (extract_attachments), ``content`` holds the raw bytes.
    After process_attachment(), small attachments instead carry a
    base64-encoded string in ``content``, while large attachments gain a
    ``path`` pointing at the file written under FILE_STORAGE_DIR.
    NOTE(review): ``content``/``path`` are therefore not always present with
    the declared types — consider ``total=False``/``NotRequired`` for them.
    """
    filename: str
    content_type: str
    size: int
    content: bytes
    path: pathlib.Path
class EmailMessage(TypedDict):
    """Structured representation of one parsed email (see parse_email_message)."""
    message_id: str
    subject: str
    sender: str
    recipients: list[str]
    sent_at: datetime | None
    body: str
    attachments: list[Attachment]


# (status, payload) pair used by the IMAP fetch helpers.
# NOTE(review): fetch_email's body builds (uid, raw_email) tuples — confirm the
# Literal["OK", "ERROR"] first element actually matches what is returned.
RawEmailResponse = tuple[Literal["OK", "ERROR"], bytes]
def extract_recipients(msg: email.message.Message) -> list[str]: def extract_recipients(msg: email.message.Message) -> list[str]:
""" """
Extract email recipients from message headers. Extract email recipients from message headers.
@ -83,7 +107,7 @@ def extract_body(msg: email.message.Message) -> str:
return body return body
def extract_attachments(msg: email.message.Message) -> list[dict]: def extract_attachments(msg: email.message.Message) -> list[Attachment]:
""" """
Extract attachment metadata and content from email. Extract attachment metadata and content from email.
@ -117,6 +141,61 @@ def extract_attachments(msg: email.message.Message) -> list[dict]:
return attachments return attachments
def process_attachment(attachment: Attachment, message_id: str) -> Attachment | None:
    """Process one attachment: inline small files, spill large ones to disk.

    Args:
        attachment: Attachment dict with metadata and raw ``content`` bytes.
        message_id: Email Message-ID, used to build a per-message directory name.

    Returns:
        A processed copy of the attachment. Attachments up to
        MAX_INLINE_ATTACHMENT_SIZE have their content base64-encoded (as str)
        so they can be stored inline in the database; larger attachments are
        written under FILE_STORAGE_DIR and get a ``path`` entry instead.
        Returns None if writing a large attachment to disk fails.
    """
    # Work on a shallow copy so the caller's dict is never mutated in place.
    processed: Attachment = dict(attachment)  # type: ignore[assignment]

    if processed["size"] <= MAX_INLINE_ATTACHMENT_SIZE:
        # NOTE: content becomes a base64 *str* here although the TypedDict
        # declares bytes — this keeps the value JSON-serializable for the DB.
        processed["content"] = base64.b64encode(processed["content"]).decode('utf-8')
        return processed

    # Sanitize identifiers so they are safe as filesystem path components.
    safe_message_id = re.sub(r'[<>\s:/\\]', '_', message_id)
    safe_filename = re.sub(r'[/\\]', '_', processed["filename"])
    # Short random prefix avoids collisions between same-named attachments.
    unique_id = str(uuid.uuid4())[:8]

    # One subdirectory per message keeps a message's files grouped together.
    message_dir = FILE_STORAGE_DIR / safe_message_id
    message_dir.mkdir(parents=True, exist_ok=True)

    file_path = message_dir / f"{unique_id}_{safe_filename}"
    try:
        file_path.write_bytes(processed["content"])
    except Exception as e:
        logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
        return None
    processed["path"] = file_path
    return processed
def process_attachments(attachments: list[Attachment], message_id: str) -> list[Attachment]:
    """
    Process email attachments, storing large files on disk and returning metadata.

    Args:
        attachments: List of attachment dictionaries with metadata and content
        message_id: Email message ID to use in file path generation

    Returns:
        List of processed attachment dictionaries; attachments whose
        processing failed (process_attachment returned None) are dropped.
    """
    if not attachments:
        return []

    processed: list[Attachment] = []
    for item in attachments:
        result = process_attachment(item, message_id)
        if result:
            processed.append(result)
    return processed
def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes: def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> bytes:
""" """
Compute a SHA-256 hash of message content. Compute a SHA-256 hash of message content.
@ -134,7 +213,7 @@ def compute_message_hash(msg_id: str, subject: str, sender: str, body: str) -> b
return hashlib.sha256(hash_content).digest() return hashlib.sha256(hash_content).digest()
def parse_email_message(raw_email: str) -> dict: def parse_email_message(raw_email: str) -> EmailMessage:
""" """
Parse raw email into structured data. Parse raw email into structured data.
@ -146,15 +225,15 @@ def parse_email_message(raw_email: str) -> dict:
""" """
msg = email.message_from_string(raw_email) msg = email.message_from_string(raw_email)
return { return EmailMessage(
"message_id": msg.get("Message-ID", ""), message_id=msg.get("Message-ID", ""),
"subject": msg.get("Subject", ""), subject=msg.get("Subject", ""),
"sender": msg.get("From", ""), sender=msg.get("From", ""),
"recipients": extract_recipients(msg), recipients=extract_recipients(msg),
"sent_at": extract_date(msg), sent_at=extract_date(msg),
"body": extract_body(msg), body=extract_body(msg),
"attachments": extract_attachments(msg) attachments=extract_attachments(msg)
} )
def create_source_item( def create_source_item(
@ -191,7 +270,7 @@ def create_source_item(
def create_mail_message( def create_mail_message(
db_session: Session, db_session: Session,
source_id: int, source_id: int,
parsed_email: dict, parsed_email: EmailMessage,
folder: str, folder: str,
) -> MailMessage: ) -> MailMessage:
""" """
@ -206,6 +285,12 @@ def create_mail_message(
Returns: Returns:
Newly created MailMessage Newly created MailMessage
""" """
processed_attachments = process_attachments(
parsed_email["attachments"],
parsed_email["message_id"]
)
print("processed_attachments", processed_attachments)
mail_message = MailMessage( mail_message = MailMessage(
source_id=source_id, source_id=source_id,
message_id=parsed_email["message_id"], message_id=parsed_email["message_id"],
@ -214,7 +299,7 @@ def create_mail_message(
recipients=parsed_email["recipients"], recipients=parsed_email["recipients"],
sent_at=parsed_email["sent_at"], sent_at=parsed_email["sent_at"],
body_raw=parsed_email["body"], body_raw=parsed_email["body"],
attachments={"items": parsed_email["attachments"], "folder": folder} attachments={"items": processed_attachments, "folder": folder}
) )
db_session.add(mail_message) db_session.add(mail_message)
return mail_message return mail_message
@ -254,7 +339,7 @@ def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
return uid, raw_email return uid, raw_email
def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> tuple[str, bytes] | None: def fetch_email(conn: imaplib.IMAP4_SSL, uid: str) -> RawEmailResponse | None:
try: try:
status, msg_data = conn.fetch(uid, '(UID RFC822)') status, msg_data = conn.fetch(uid, '(UID RFC822)')
if status != 'OK' or not msg_data or not msg_data[0]: if status != 'OK' or not msg_data or not msg_data[0]:
@ -271,7 +356,7 @@ def fetch_email_since(
conn: imaplib.IMAP4_SSL, conn: imaplib.IMAP4_SSL,
folder: str, folder: str,
since_date: datetime = datetime(1970, 1, 1) since_date: datetime = datetime(1970, 1, 1)
) -> list[tuple[str, bytes]]: ) -> list[RawEmailResponse]:
""" """
Fetch emails from a folder since a given date. Fetch emails from a folder since a given date.

View File

@ -34,6 +34,10 @@ def process_message(
Returns: Returns:
source_id if successful, None otherwise source_id if successful, None otherwise
""" """
if not raw_email.strip():
logger.warning(f"Empty email message received for account {account_id}")
return None
with make_session() as db: with make_session() as db:
account = db.query(EmailAccount).get(account_id) account = db.query(EmailAccount).get(account_id)
if not account: if not account:
@ -95,7 +99,7 @@ def sync_account(account_id: int) -> dict:
try: try:
with imap_connection(account) as conn: with imap_connection(account) as conn:
for folder in folders_to_process: for folder in folders_to_process:
folder_stats = process_folder(conn, folder, account, since_date) folder_stats = process_folder(conn, folder, account, since_date, process_message.delay)
messages_found += folder_stats["messages_found"] messages_found += folder_stats["messages_found"]
new_messages += folder_stats["new_messages"] new_messages += folder_stats["new_messages"]

View File

@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
import os import os
import subprocess import subprocess
from unittest.mock import patch
import uuid import uuid
from pathlib import Path from pathlib import Path
@ -41,7 +42,7 @@ def create_test_database(test_db_name: str) -> str:
def drop_test_database(test_db_name: str) -> None: def drop_test_database(test_db_name: str) -> None:
""" """
Drop the test database. Drop the test database after terminating all active connections.
Args: Args:
test_db_name: Name of the test database to drop test_db_name: Name of the test database to drop
@ -50,8 +51,24 @@ def drop_test_database(test_db_name: str) -> None:
with admin_engine.connect() as conn: with admin_engine.connect() as conn:
conn.execute(text("COMMIT")) # Close any open transaction conn.execute(text("COMMIT")) # Close any open transaction
# Terminate all connections to the database
conn.execute(
text(
f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{test_db_name}'
AND pid <> pg_backend_pid()
"""
)
)
# Drop the database
conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}")) conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}"))
admin_engine.dispose()
def run_alembic_migrations(db_name: str) -> None: def run_alembic_migrations(db_name: str) -> None:
"""Run all Alembic migrations on the test database.""" """Run all Alembic migrations on the test database."""
@ -83,7 +100,8 @@ def test_db():
run_alembic_migrations(test_db_name) run_alembic_migrations(test_db_name)
# Return the URL to the test database # Return the URL to the test database
yield test_db_url with patch("memory.common.settings.DB_URL", test_db_url):
yield test_db_url
finally: finally:
# Clean up - drop the test database # Clean up - drop the test database
drop_test_database(test_db_name) drop_test_database(test_db_name)
@ -173,3 +191,9 @@ def email_provider():
], ],
} }
) )
@pytest.fixture(autouse=True)
def mock_file_storage(tmp_path: Path):
    # Redirect attachment storage to a per-test temporary directory so tests
    # never write into the real FILE_STORAGE_DIR. autouse=True makes this
    # apply to every test in scope without explicit opt-in.
    with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
        yield

View File

@ -1,89 +1,39 @@
import pytest import pytest
from datetime import datetime, timedelta from datetime import datetime, timedelta
from memory.common.db.models import EmailAccount from memory.common.db.models import EmailAccount, MailMessage, SourceItem
from memory.workers.tasks.email import process_message, sync_account, sync_all_accounts from memory.workers.tasks.email import process_message
# from ..email_provider import MockEmailProvider
@pytest.fixture # Test email constants
def sample_emails(): SIMPLE_EMAIL_RAW = """From: alice@example.com
"""Fixture providing a sample set of test emails across different folders.""" To: bob@example.com
now = datetime.now() Subject: Test Email 1
yesterday = now - timedelta(days=1) Message-ID: <test-101@example.com>
last_week = now - timedelta(days=7) Date: Tue, 14 May 2024 10:00:00 +0000
return { This is test email 1"""
"INBOX": [
{ EMAIL_WITH_ATTACHMENT_RAW = """From: eve@example.com
"uid": 101, To: bob@example.com
"flags": "\\Seen", Subject: Email with Attachment
"date": now.strftime("%a, %d %b %Y %H:%M:%S +0000"), Message-ID: <test-302@example.com>
"date_internal": now.strftime("%d-%b-%Y %H:%M:%S +0000"), Date: Tue, 7 May 2024 10:00:00 +0000
"from": "alice@example.com", Content-Type: multipart/mixed; boundary="boundary123"
"to": "bob@example.com",
"subject": "Recent Test Email", --boundary123
"message_id": "<test-101@example.com>", Content-Type: text/plain
"body": "This is a recent test email"
}, This email has an attachment
{
"uid": 102, --boundary123
"flags": "", Content-Type: text/plain; name="test.txt"
"date": yesterday.strftime("%a, %d %b %Y %H:%M:%S +0000"), Content-Disposition: attachment; filename="test.txt"
"date_internal": yesterday.strftime("%d-%b-%Y %H:%M:%S +0000"), Content-Transfer-Encoding: base64
"from": "charlie@example.com",
"to": "bob@example.com", VGhpcyBpcyBhIHRlc3QgYXR0YWNobWVudA==
"subject": "Yesterday's Email",
"message_id": "<test-102@example.com>", --boundary123--"""
"body": "This email was sent yesterday"
}
],
"Sent": [
{
"uid": 201,
"flags": "\\Seen",
"date": yesterday.strftime("%a, %d %b %Y %H:%M:%S +0000"),
"date_internal": yesterday.strftime("%d-%b-%Y %H:%M:%S +0000"),
"from": "bob@example.com",
"to": "alice@example.com",
"subject": "Re: Test Email",
"message_id": "<test-201@example.com>",
"body": "This is a reply to the test email"
}
],
"Archive": [
{
"uid": 301,
"flags": "\\Seen",
"date": last_week.strftime("%a, %d %b %Y %H:%M:%S +0000"),
"date_internal": last_week.strftime("%d-%b-%Y %H:%M:%S +0000"),
"from": "david@example.com",
"to": "bob@example.com",
"subject": "Old Email",
"message_id": "<test-301@example.com>",
"body": "This is an old email from last week"
},
{
"uid": 302,
"flags": "\\Seen",
"date": last_week.strftime("%a, %d %b %Y %H:%M:%S +0000"),
"date_internal": last_week.strftime("%d-%b-%Y %H:%M:%S +0000"),
"from": "eve@example.com",
"to": "bob@example.com",
"subject": "Email with Attachment",
"message_id": "<test-302@example.com>",
"body": "This email has an attachment",
"attachments": [
{
"filename": "test.txt",
"maintype": "text",
"subtype": "plain",
"content": b"This is a test attachment"
}
]
}
]
}
@pytest.fixture @pytest.fixture
@ -104,3 +54,104 @@ def test_email_account(db_session):
db_session.add(account) db_session.add(account)
db_session.commit() db_session.commit()
return account return account
def test_process_simple_email(db_session, test_email_account):
    """Process a plain-text email and verify both DB records it creates."""
    source_id = process_message(
        account_id=test_email_account.id,
        message_id="101",
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )
    assert source_id is not None

    # A SourceItem row should exist with the expected mail metadata.
    source = db_session.query(SourceItem).filter(SourceItem.id == source_id).first()
    assert source is not None
    assert source.modality == "mail"
    assert source.tags == test_email_account.tags
    assert source.mime_type == "message/rfc822"
    assert source.embed_status == "RAW"

    # A MailMessage row should be linked back to that source item.
    message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()
    assert message is not None
    assert message.subject == "Test Email 1"
    assert message.sender == "alice@example.com"
    assert "bob@example.com" in message.recipients
    assert "This is test email 1" in message.body_raw
    assert message.attachments.get("folder") == "INBOX"
def test_process_email_with_attachment(db_session, test_email_account):
    """Process a multipart message and verify its attachment is recorded."""
    source_id = process_message(
        account_id=test_email_account.id,
        message_id="302",
        folder="Archive",
        raw_email=EMAIL_WITH_ATTACHMENT_RAW,
    )
    assert source_id is not None

    # Verify the mail message fields.
    message = db_session.query(MailMessage).filter(MailMessage.source_id == source_id).first()
    assert message is not None
    assert message.subject == "Email with Attachment"
    assert message.sender == "eve@example.com"
    assert "This email has an attachment" in message.body_raw
    assert message.attachments.get("folder") == "Archive"

    # The attachment should have been processed and stored under "items".
    items = message.attachments.get("items", [])
    assert len(items) > 0
    assert items[0]["filename"] == "test.txt"
    assert items[0]["content_type"] == "text/plain"
def test_process_empty_message(db_session, test_email_account):
    """An empty raw email must be rejected: no source_id is returned."""
    result = process_message(
        account_id=test_email_account.id,
        message_id="999",
        folder="Archive",
        raw_email="",
    )
    assert result is None
def test_process_duplicate_message(db_session, test_email_account):
    """Re-processing an identical email must be detected and create nothing."""
    # Initial processing creates the records.
    first_id = process_message(
        account_id=test_email_account.id,
        message_id="101",
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )
    assert first_id is not None, "First call should return a source_id"

    # Snapshot row counts so we can prove nothing new gets inserted.
    sources_before = db_session.query(SourceItem).count()
    messages_before = db_session.query(MailMessage).count()

    # Processing the exact same email again must be treated as a duplicate.
    second_id = process_message(
        account_id=test_email_account.id,
        message_id="101",
        folder="INBOX",
        raw_email=SIMPLE_EMAIL_RAW,
    )
    assert second_id is None, "Second call should return None for duplicate message"

    assert db_session.query(SourceItem).count() == sources_before, "No new SourceItem should be created"
    assert db_session.query(MailMessage).count() == messages_before, "No new MailMessage should be created"

View File

@ -135,49 +135,3 @@ class MockEmailProvider:
# No matching email found # No matching email found
return ('NO', [b'Email not found']) return ('NO', [b'Email not found'])
# def uid(self, command: str, *args) -> tuple[str, list]:
# """
# Handle UID-based commands like SEARCH and FETCH.
# Args:
# command: The IMAP command (SEARCH, FETCH, etc.)
# *args: Additional arguments for the command
# Returns:
# IMAP-style response appropriate for the command
# """
# if not self.current_folder or self.current_folder not in self.emails_by_folder:
# return ('OK', [b''])
# if command == 'SEARCH':
# # For simplicity, return all UIDs in the current folder
# # A real implementation would parse the search criteria
# uids = [str(email["uid"]).encode() for email in self.emails_by_folder[self.current_folder]]
# return ('OK', [b' '.join(uids) if uids else b''])
# elif command == 'FETCH':
# # Parse the UID from the arguments
# uid_arg = args[0].decode() if isinstance(args[0], bytes) else args[0]
# uid = int(uid_arg)
# # Find the email with the matching UID
# for email_data in self.emails_by_folder[self.current_folder]:
# if email_data["uid"] == uid:
# # Generate email content
# email_string = self._generate_email_string(email_data)
# flags = email_data.get("flags", "\\Seen")
# date = email_data.get("date_internal", "01-Jan-2023 00:00:00 +0000")
# # Format the response as expected by the IMAP client
# response = [(
# f'{uid} (UID {uid} FLAGS ({flags}) INTERNALDATE "{date}" RFC822 '
# f'{{{len(email_string)}}}'.encode(),
# email_string.encode()
# )]
# return ('OK', response)
# # No matching email found
# return ('NO', [b'Email not found'])
# return ('NO', [f'Command {command} not implemented'.encode()])