From 14aa6ff9beb87cc1d7beba9a314369cea3f27df7 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 27 Apr 2025 20:20:43 +0200 Subject: [PATCH] safe attachments to disk --- src/memory/workers/email.py | 15 +- tests/memory/workers/test_email.py | 438 ++++++++++++++++++----------- 2 files changed, 284 insertions(+), 169 deletions(-) diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index 85e1ffa..ce55a58 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -13,7 +13,7 @@ import pathlib from sqlalchemy.orm import Session from memory.common.db.models import EmailAccount, MailMessage, SourceItem -from memory.common.settings import FILE_STORAGE_DIR, MAX_INLINE_ATTACHMENT_SIZE +from memory.common import settings logger = logging.getLogger(__name__) @@ -151,16 +151,18 @@ def process_attachment(attachment: Attachment, message_id: str) -> Attachment | Returns: Processed attachment dictionary with appropriate metadata """ - if attachment["size"] <= MAX_INLINE_ATTACHMENT_SIZE: - attachment["content"] = base64.b64encode(attachment["content"]).decode('utf-8') + if not (content := attachment.get("content")): return attachment + if attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE: + return {**attachment, "content": base64.b64encode(content).decode('utf-8')} + safe_message_id = re.sub(r'[<>\s:/\\]', '_', message_id) unique_id = str(uuid.uuid4())[:8] safe_filename = re.sub(r'[/\\]', '_', attachment["filename"]) # Create user subdirectory - user_dir = FILE_STORAGE_DIR / safe_message_id + user_dir = settings.FILE_STORAGE_DIR / safe_message_id user_dir.mkdir(parents=True, exist_ok=True) # Final path for the attachment @@ -168,9 +170,8 @@ def process_attachment(attachment: Attachment, message_id: str) -> Attachment | # Write the file try: - file_path.write_bytes(attachment["content"]) - attachment["path"] = file_path - return attachment + file_path.write_bytes(content) + return {**attachment, "path": file_path} except Exception as e: logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}") return None diff --git a/tests/memory/workers/test_email.py b/tests/memory/workers/test_email.py index 051d9d7..73f78b1 100644 --- a/tests/memory/workers/test_email.py +++ b/tests/memory/workers/test_email.py @@ -2,13 +2,17 @@ import email import email.mime.multipart import email.mime.text import email.mime.base +import base64 +import pathlib +import re + from datetime import datetime from email.utils import formatdate from unittest.mock import ANY, MagicMock, patch import pytest -import imaplib from memory.common.db.models import SourceItem from memory.common.db.models import MailMessage, EmailAccount +from memory.common import settings from memory.workers.email import ( compute_message_hash, create_source_item, @@ -23,8 +27,10 @@ from memory.workers.email import ( fetch_email, fetch_email_since, process_folder, + process_attachment, + process_attachments, ) -from tests.providers.email_provider import MockEmailProvider + # Use a simple counter to generate unique message IDs without calling make_msgid @@ -54,23 +60,25 @@ def create_email_message( if multipart: msg = email.mime.multipart.MIMEMultipart() msg.attach(email.mime.text.MIMEText(body)) - + if attachments: for attachment in attachments: - attachment_part = email.mime.base.MIMEBase("application", "octet-stream") + attachment_part = email.mime.base.MIMEBase( + "application", "octet-stream" + ) attachment_part.set_payload(attachment["content"]) attachment_part.add_header( - "Content-Disposition", - f"attachment; filename={attachment['filename']}" + "Content-Disposition", + f"attachment; filename={attachment['filename']}", ) msg.attach(attachment_part) else: msg = email.mime.text.MIMEText(body) - + msg["Subject"] = subject msg["From"] = from_addr msg["To"] = to_addrs - + if cc_addrs: msg["Cc"] = cc_addrs if bcc_addrs: @@ -81,7 +89,7 @@ def create_email_message( msg["Message-ID"] = message_id else: msg["Message-ID"] = _generate_test_message_id() - + return msg @@ -89,41 +97,31 @@ def create_email_message( "to_addr, cc_addr, bcc_addr, expected", [ # Single recipient in To field - ( - "recipient@example.com", - None, - None, - ["recipient@example.com"] - ), + ("recipient@example.com", None, None, ["recipient@example.com"]), # Multiple recipients in To field ( - "recipient1@example.com, recipient2@example.com", - None, - None, - ["recipient1@example.com", "recipient2@example.com"] + "recipient1@example.com, recipient2@example.com", + None, + None, + ["recipient1@example.com", "recipient2@example.com"], ), # To, Cc fields ( - "recipient@example.com", - "cc@example.com", - None, - ["recipient@example.com", "cc@example.com"] + "recipient@example.com", + "cc@example.com", + None, + ["recipient@example.com", "cc@example.com"], ), # To, Cc, Bcc fields ( - "recipient@example.com", - "cc@example.com", - "bcc@example.com", - ["recipient@example.com", "cc@example.com", "bcc@example.com"] + "recipient@example.com", + "cc@example.com", + "bcc@example.com", + ["recipient@example.com", "cc@example.com", "bcc@example.com"], ), # Empty fields - ( - "", - "", - "", - [] - ), - ] + ("", "", "", []), + ], ) def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected): msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr) @@ -143,7 +141,7 @@ def test_extract_date_missing(): "Monday, Jan 1, 2023", # Descriptive but not RFC compliant "01/01/2023", # Common format but not RFC compliant "", # Empty string - ] + ], ) def test_extract_date_invalid_formats(date_str): msg = create_email_message() @@ -157,34 +155,32 @@ def test_extract_date_invalid_formats(date_str): "Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format "01 Jan 2023 12:00:00 +0000", # RFC 822 format "Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name - ] + ], ) def test_extract_date(date_str): msg = create_email_message() msg["Date"] = date_str result = extract_date(msg) - + assert result is not None assert result.year == 2023 assert result.month == 1 assert result.day == 1 -@pytest.mark.parametrize('multipart', [True, False]) +@pytest.mark.parametrize("multipart", [True, False]) def test_extract_body_text_plain(multipart): body_content = "This is a test email body" msg = create_email_message(body=body_content, multipart=multipart) extracted = extract_body(msg) - + # Strip newlines for comparison since multipart emails often add them assert extracted.strip() == body_content.strip() def test_extract_body_with_attachments(): body_content = "This is a test email body" - attachments = [ - {"filename": "test.txt", "content": b"attachment content"} - ] + attachments = [{"filename": "test.txt", "content": b"attachment content"}] msg = create_email_message(body=body_content, attachments=attachments) assert body_content in extract_body(msg) @@ -197,10 +193,10 @@ def test_extract_attachments_none(): def test_extract_attachments_with_files(): attachments = [ {"filename": "test1.txt", "content": b"content1"}, - {"filename": "test2.pdf", "content": b"content2"} + {"filename": "test2.pdf", "content": b"content2"}, ] msg = create_email_message(attachments=attachments) - + result = extract_attachments(msg) assert len(result) == 2 assert result[0]["filename"] == "test1.txt" @@ -212,32 +208,156 @@ def test_extract_attachments_non_multipart(): assert extract_attachments(msg) == [] +@pytest.mark.parametrize( + "attachment_size, max_inline_size, message_id", + [ + # Small attachment, should be base64 encoded and returned inline + (100, 1000, ""), + # Edge case: exactly at max size, should be base64 encoded + (100, 100, ""), + ], +) +def test_process_attachment_inline(attachment_size, max_inline_size, message_id): + attachment = { + "filename": "test.txt", + "content_type": "text/plain", + "size": attachment_size, + "content": b"a" * attachment_size, + } + + with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): + result = process_attachment(attachment, message_id) + + assert result is not None + # For inline attachments, content should be base64 encoded string + assert isinstance(result["content"], str) + # Decode the base64 string and compare with the original content + decoded_content = base64.b64decode(result["content"].encode('utf-8')) + assert decoded_content == attachment["content"] + assert "path" not in result + + +@pytest.mark.parametrize( + "attachment_size, max_inline_size, message_id", + [ + # Large attachment, should be saved to disk + (1000, 100, ""), + # Message ID with special characters that need escaping + (1000, 100, ""), + ], +) +def test_process_attachment_disk(attachment_size, max_inline_size, message_id): + attachment = { + "filename": "test.txt", + "content_type": "text/plain", + "size": attachment_size, + "content": b"a" * attachment_size, + } + + with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size): + result = process_attachment(attachment, message_id) + + assert result is not None + # For disk-stored attachments, content should not be modified and path should be set + assert "path" in result + assert isinstance(result["path"], pathlib.Path) + + # Verify the path contains safe message ID + safe_message_id = re.sub(r"[<>\s:/\\]", "_", message_id) + assert safe_message_id in str(result["path"]) + + +def test_process_attachment_write_error(): + # Create test attachment + attachment = { + "filename": "test_error.txt", + "content_type": "text/plain", + "size": 100, + "content": b"a" * 100, + } + + # Mock write_bytes to raise an exception + def mock_write_bytes(self, content): + raise IOError("Test write error") + + with ( + patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10), + patch.object(pathlib.Path, "write_bytes", mock_write_bytes), + ): + assert process_attachment(attachment, "") is None + + +def test_process_attachments_empty(): + assert process_attachments([], "") == [] + + +def test_process_attachments_mixed(): + # Create test attachments + attachments = [ + # Small attachment - should be kept inline + { + "filename": "small.txt", + "content_type": "text/plain", + "size": 20, + "content": b"a" * 20, + }, + # Large attachment - should be stored on disk + { + "filename": "large.txt", + "content_type": "text/plain", + "size": 100, + "content": b"b" * 100, + }, + # Another small attachment + { + "filename": "another_small.txt", + "content_type": "text/plain", + "size": 30, + "content": b"c" * 30, + }, + ] + + with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50): + # Process attachments + results = process_attachments(attachments, "") + + # Verify we have all attachments processed + assert len(results) == 3 + + # Verify small attachments are base64 encoded + assert isinstance(results[0]["content"], str) + assert isinstance(results[2]["content"], str) + + # Verify large attachment has a path + assert "path" in results[1] + + @pytest.mark.parametrize( "msg_id, subject, sender, body, expected", [ ( - "", - "Test Subject", - "sender@example.com", - "Test body", - b"\xf2\xbd" # First two bytes of the actual hash + "", + "Test Subject", + "sender@example.com", + "Test body", + b"\xf2\xbd", # First two bytes of the actual hash ), ( - "", - "Test Subject", - "sender@example.com", + "", + "Test Subject", + "sender@example.com", "Test body", - b"\xa4\x15" # Will be different from the first hash + b"\xa4\x15", # Will be different from the first hash ), - ] + ], ) def test_compute_message_hash(msg_id, subject, sender, body, expected): result = compute_message_hash(msg_id, subject, sender, body) - + # Verify it's bytes and correct length for SHA-256 (32 bytes) assert isinstance(result, bytes) assert len(result) == 32 - + # Verify first two bytes match expected assert result[:2] == expected @@ -256,11 +376,11 @@ def test_parse_simple_email(): to_addrs="recipient@example.com", date=test_date, body="Test body content", - message_id=msg_id + message_id=msg_id, ) - + result = parse_email_message(msg.as_string()) - + assert result == { "message_id": msg_id, "subject": "Test Subject", @@ -274,47 +394,45 @@ def test_parse_simple_email(): def test_parse_email_with_attachments(): - attachments = [ - {"filename": "test.txt", "content": b"attachment content"} - ] + attachments = [{"filename": "test.txt", "content": b"attachment content"}] msg = create_email_message(attachments=attachments) - + result = parse_email_message(msg.as_string()) - + assert len(result["attachments"]) == 1 assert result["attachments"][0]["filename"] == "test.txt" def test_extract_email_uid_valid(): - msg_data = [(b'1 (UID 12345 RFC822 {1234}', b'raw email content')] + msg_data = [(b"1 (UID 12345 RFC822 {1234}", b"raw email content")] uid, raw_email = extract_email_uid(msg_data) - + assert uid == "12345" - assert raw_email == b'raw email content' + assert raw_email == b"raw email content" def test_extract_email_uid_no_match(): - msg_data = [(b'1 (RFC822 {1234}', b'raw email content')] + msg_data = [(b"1 (RFC822 {1234}", b"raw email content")] uid, raw_email = extract_email_uid(msg_data) - + assert uid is None - assert raw_email == b'raw email content' + assert raw_email == b"raw email content" def test_create_source_item(db_session): # Mock data - message_hash = b'test_hash_bytes' + bytes(28) # 32 bytes for SHA-256 + message_hash = b"test_hash_bytes" + bytes(28) # 32 bytes for SHA-256 account_tags = ["work", "important"] raw_email_size = 1024 - + # Call function source_item = create_source_item( db_session=db_session, message_hash=message_hash, account_tags=account_tags, - raw_email_size=raw_email_size + raw_email_size=raw_email_size, ) - + # Verify the source item was created correctly assert isinstance(source_item, SourceItem) assert source_item.id is not None @@ -324,7 +442,7 @@ def test_create_source_item(db_session): assert source_item.byte_length == raw_email_size assert source_item.mime_type == "message/rfc822" assert source_item.embed_status == "RAW" - + # Verify it was added to the session db_session.flush() fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one() @@ -339,66 +457,64 @@ def test_create_source_item(db_session): ( lambda db: ( # First create source_item to satisfy foreign key constraint - db.add(SourceItem( - id=1, - modality="mail", - sha256=b'some_hash_bytes' + bytes(28), - tags=["test"], - byte_length=100, - mime_type="message/rfc822", - embed_status="RAW" - )), + db.add( + SourceItem( + id=1, + modality="mail", + sha256=b"some_hash_bytes" + bytes(28), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW", + ) + ), db.flush(), # Then create mail_message - db.add(MailMessage( - source_id=1, - message_id="", - subject="Test", - sender="test@example.com", - recipients=["recipient@example.com"], - body_raw="Test body" - )) + db.add( + MailMessage( + source_id=1, + message_id="", + subject="Test", + sender="test@example.com", + recipients=["recipient@example.com"], + body_raw="Test body", + ) + ), ), "", b"unmatched_hash", - True + True, ), # Test by non-existent message ID - ( - lambda db: None, - "", - b"unmatched_hash", - False - ), + (lambda db: None, "", b"unmatched_hash", False), # Test by hash ( - lambda db: db.add(SourceItem( - modality="mail", - sha256=b'test_hash_bytes' + bytes(28), - tags=["test"], - byte_length=100, - mime_type="message/rfc822", - embed_status="RAW" - )), + lambda db: db.add( + SourceItem( + modality="mail", + sha256=b"test_hash_bytes" + bytes(28), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW", + ) + ), "", - b'test_hash_bytes' + bytes(28), - True + b"test_hash_bytes" + bytes(28), + True, ), # Test by non-existent hash - ( - lambda db: None, - "", - b'different_hash_' + bytes(28), - False - ), - ] + (lambda db: None, "", b"different_hash_" + bytes(28), False), + ], ) -def test_check_message_exists(db_session, setup_db, message_id, message_hash, expected_exists): +def test_check_message_exists( + db_session, setup_db, message_id, message_hash, expected_exists +): # Setup test data if setup_db: setup_db(db_session) db_session.flush() - + # Test the function assert check_message_exists(db_session, message_id, message_hash) == expected_exists @@ -412,18 +528,20 @@ def test_create_mail_message(db_session): "recipients": ["recipient@example.com"], "sent_at": datetime(2023, 1, 1, 12, 0, 0), "body": "Test body content", - "attachments": [{"filename": "test.txt", "content_type": "text/plain", "size": 100}] + "attachments": [ + {"filename": "test.txt", "content_type": "text/plain", "size": 100} + ], } folder = "INBOX" - + # Call function mail_message = create_mail_message( db_session=db_session, source_id=source_id, parsed_email=parsed_email, - folder=folder + folder=folder, ) - + # Verify the mail message was created correctly assert isinstance(mail_message, MailMessage) assert mail_message.source_id == source_id @@ -433,22 +551,25 @@ def test_create_mail_message(db_session): assert mail_message.recipients == parsed_email["recipients"] assert mail_message.sent_at == parsed_email["sent_at"] assert mail_message.body_raw == parsed_email["body"] - assert mail_message.attachments == {"items": parsed_email["attachments"], "folder": folder} + assert mail_message.attachments == { + "items": parsed_email["attachments"], + "folder": folder, + } def test_fetch_email(email_provider): # Configure the provider with sample emails email_provider.select("INBOX") - + # Test fetching an existing email result = fetch_email(email_provider, "101") - + # Verify result contains the expected UID and content assert result is not None uid, content = result assert uid == "101" assert b"This is test email 1" in content - + # Test fetching a non-existent email result = fetch_email(email_provider, "999") assert result is None @@ -457,63 +578,56 @@ def test_fetch_email(email_provider): def test_fetch_email_since(email_provider): # Fetch emails from INBOX folder result = fetch_email_since(email_provider, "INBOX", datetime(1970, 1, 1)) - + # Verify we got the expected number of emails assert len(result) == 2 - + # Verify content of fetched emails uids = sorted([uid for uid, _ in result]) assert uids == ["101", "102"] - + # Test with a folder that doesn't exist - result = fetch_email_since(email_provider, "NonExistentFolder", datetime(1970, 1, 1)) + result = fetch_email_since( + email_provider, "NonExistentFolder", datetime(1970, 1, 1) + ) assert result == [] -@patch('memory.workers.tasks.email.process_message.delay') -def test_process_folder(mock_process_message_delay, email_provider): +def test_process_folder(email_provider): account = MagicMock(spec=EmailAccount) account.id = 123 account.tags = ["test"] - - results = process_folder(email_provider, "INBOX", account, datetime(1970, 1, 1), mock_process_message_delay) - - assert results == { - "messages_found": 2, - "new_messages": 2, - "errors": 0 - } - -@patch('memory.workers.tasks.email.process_message.delay') -def test_process_folder_no_emails(mock_process_message_delay, email_provider): + results = process_folder( + email_provider, "INBOX", account, datetime(1970, 1, 1), MagicMock() + ) + + assert results == {"messages_found": 2, "new_messages": 2, "errors": 0} + + +def test_process_folder_no_emails(email_provider): account = MagicMock(spec=EmailAccount) account.id = 123 - email_provider.search = MagicMock(return_value=("OK", [b''])) - - result = process_folder(email_provider, "Empty", account, datetime(1970, 1, 1), mock_process_message_delay) - assert result == { - "messages_found": 0, - "new_messages": 0, - "errors": 0 - } + email_provider.search = MagicMock(return_value=("OK", [b""])) + + result = process_folder( + email_provider, "Empty", account, datetime(1970, 1, 1), MagicMock() + ) + assert result == {"messages_found": 0, "new_messages": 0, "errors": 0} def test_process_folder_error(email_provider): account = MagicMock(spec=EmailAccount) account.id = 123 - + mock_processor = MagicMock() - + def raise_exception(*args): raise Exception("Test error") - - email_provider.search = raise_exception - - result = process_folder(email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor) - assert result == { - "messages_found": 0, - "new_messages": 0, - "errors": 0 - } + email_provider.search = raise_exception + + result = process_folder( + email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor + ) + assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}