save attachments to disk

This commit is contained in:
Daniel O'Connell 2025-04-27 20:20:43 +02:00
parent 128f8e3d64
commit 14aa6ff9be
2 changed files with 284 additions and 169 deletions

View File

@ -13,7 +13,7 @@ import pathlib
from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem
from memory.common.settings import FILE_STORAGE_DIR, MAX_INLINE_ATTACHMENT_SIZE
from memory.common import settings
logger = logging.getLogger(__name__)
@ -151,16 +151,18 @@ def process_attachment(attachment: Attachment, message_id: str) -> Attachment |
Returns:
Processed attachment dictionary with appropriate metadata
"""
if attachment["size"] <= MAX_INLINE_ATTACHMENT_SIZE:
attachment["content"] = base64.b64encode(attachment["content"]).decode('utf-8')
if not (content := attachment.get("content")):
return attachment
if attachment["size"] <= settings.MAX_INLINE_ATTACHMENT_SIZE:
return {**attachment, "content": base64.b64encode(content).decode('utf-8')}
safe_message_id = re.sub(r'[<>\s:/\\]', '_', message_id)
unique_id = str(uuid.uuid4())[:8]
safe_filename = re.sub(r'[/\\]', '_', attachment["filename"])
# Create user subdirectory
user_dir = FILE_STORAGE_DIR / safe_message_id
user_dir = settings.FILE_STORAGE_DIR / safe_message_id
user_dir.mkdir(parents=True, exist_ok=True)
# Final path for the attachment
@ -168,9 +170,8 @@ def process_attachment(attachment: Attachment, message_id: str) -> Attachment |
# Write the file
try:
file_path.write_bytes(attachment["content"])
attachment["path"] = file_path
return attachment
file_path.write_bytes(content)
return {**attachment, "path": file_path}
except Exception as e:
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
return None

View File

@ -2,13 +2,17 @@ import email
import email.mime.multipart
import email.mime.text
import email.mime.base
import base64
import pathlib
import re
from datetime import datetime
from email.utils import formatdate
from unittest.mock import ANY, MagicMock, patch
import pytest
import imaplib
from memory.common.db.models import SourceItem
from memory.common.db.models import MailMessage, EmailAccount
from memory.common import settings
from memory.workers.email import (
compute_message_hash,
create_source_item,
@ -23,8 +27,10 @@ from memory.workers.email import (
fetch_email,
fetch_email_since,
process_folder,
process_attachment,
process_attachments,
)
from tests.providers.email_provider import MockEmailProvider
# Use a simple counter to generate unique message IDs without calling make_msgid
@ -54,23 +60,25 @@ def create_email_message(
if multipart:
msg = email.mime.multipart.MIMEMultipart()
msg.attach(email.mime.text.MIMEText(body))
if attachments:
for attachment in attachments:
attachment_part = email.mime.base.MIMEBase("application", "octet-stream")
attachment_part = email.mime.base.MIMEBase(
"application", "octet-stream"
)
attachment_part.set_payload(attachment["content"])
attachment_part.add_header(
"Content-Disposition",
f"attachment; filename={attachment['filename']}"
"Content-Disposition",
f"attachment; filename={attachment['filename']}",
)
msg.attach(attachment_part)
else:
msg = email.mime.text.MIMEText(body)
msg["Subject"] = subject
msg["From"] = from_addr
msg["To"] = to_addrs
if cc_addrs:
msg["Cc"] = cc_addrs
if bcc_addrs:
@ -81,7 +89,7 @@ def create_email_message(
msg["Message-ID"] = message_id
else:
msg["Message-ID"] = _generate_test_message_id()
return msg
@ -89,41 +97,31 @@ def create_email_message(
"to_addr, cc_addr, bcc_addr, expected",
[
# Single recipient in To field
(
"recipient@example.com",
None,
None,
["recipient@example.com"]
),
("recipient@example.com", None, None, ["recipient@example.com"]),
# Multiple recipients in To field
(
"recipient1@example.com, recipient2@example.com",
None,
None,
["recipient1@example.com", "recipient2@example.com"]
"recipient1@example.com, recipient2@example.com",
None,
None,
["recipient1@example.com", "recipient2@example.com"],
),
# To, Cc fields
(
"recipient@example.com",
"cc@example.com",
None,
["recipient@example.com", "cc@example.com"]
"recipient@example.com",
"cc@example.com",
None,
["recipient@example.com", "cc@example.com"],
),
# To, Cc, Bcc fields
(
"recipient@example.com",
"cc@example.com",
"bcc@example.com",
["recipient@example.com", "cc@example.com", "bcc@example.com"]
"recipient@example.com",
"cc@example.com",
"bcc@example.com",
["recipient@example.com", "cc@example.com", "bcc@example.com"],
),
# Empty fields
(
"",
"",
"",
[]
),
]
("", "", "", []),
],
)
def test_extract_recipients(to_addr, cc_addr, bcc_addr, expected):
msg = create_email_message(to_addrs=to_addr, cc_addrs=cc_addr, bcc_addrs=bcc_addr)
@ -143,7 +141,7 @@ def test_extract_date_missing():
"Monday, Jan 1, 2023", # Descriptive but not RFC compliant
"01/01/2023", # Common format but not RFC compliant
"", # Empty string
]
],
)
def test_extract_date_invalid_formats(date_str):
msg = create_email_message()
@ -157,34 +155,32 @@ def test_extract_date_invalid_formats(date_str):
"Mon, 01 Jan 2023 12:00:00 +0000", # RFC 5322 format
"01 Jan 2023 12:00:00 +0000", # RFC 822 format
"Mon, 01 Jan 2023 12:00:00 GMT", # With timezone name
]
],
)
def test_extract_date(date_str):
msg = create_email_message()
msg["Date"] = date_str
result = extract_date(msg)
assert result is not None
assert result.year == 2023
assert result.month == 1
assert result.day == 1
@pytest.mark.parametrize('multipart', [True, False])
@pytest.mark.parametrize("multipart", [True, False])
def test_extract_body_text_plain(multipart):
body_content = "This is a test email body"
msg = create_email_message(body=body_content, multipart=multipart)
extracted = extract_body(msg)
# Strip newlines for comparison since multipart emails often add them
assert extracted.strip() == body_content.strip()
def test_extract_body_with_attachments():
body_content = "This is a test email body"
attachments = [
{"filename": "test.txt", "content": b"attachment content"}
]
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(body=body_content, attachments=attachments)
assert body_content in extract_body(msg)
@ -197,10 +193,10 @@ def test_extract_attachments_none():
def test_extract_attachments_with_files():
attachments = [
{"filename": "test1.txt", "content": b"content1"},
{"filename": "test2.pdf", "content": b"content2"}
{"filename": "test2.pdf", "content": b"content2"},
]
msg = create_email_message(attachments=attachments)
result = extract_attachments(msg)
assert len(result) == 2
assert result[0]["filename"] == "test1.txt"
@ -212,32 +208,156 @@ def test_extract_attachments_non_multipart():
assert extract_attachments(msg) == []
@pytest.mark.parametrize(
    "attachment_size, max_inline_size, message_id",
    [
        # Comfortably below the inline limit.
        (100, 1000, "<test@example.com>"),
        # Boundary case: size equal to the limit still counts as inline.
        (100, 100, "<test@example.com>"),
    ],
)
def test_process_attachment_inline(attachment_size, max_inline_size, message_id):
    """Attachments no larger than the inline limit come back base64-encoded.

    The limit is patched on the ``settings`` module for the duration of the
    call. The returned ``content`` must be a ``str`` that decodes back to the
    bytes still held by the input attachment, and no ``path`` key may be set.
    """
    attachment = dict(
        filename="test.txt",
        content_type="text/plain",
        size=attachment_size,
        content=b"a" * attachment_size,
    )

    with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size):
        processed = process_attachment(attachment, message_id)

    assert processed is not None

    # Inline attachments carry their payload as a base64 string.
    encoded = processed["content"]
    assert isinstance(encoded, str)

    # Round-trip the payload and compare it with the input's content.
    assert base64.b64decode(encoded.encode("utf-8")) == attachment["content"]
    assert "path" not in processed
@pytest.mark.parametrize(
    "attachment_size, max_inline_size, message_id",
    [
        # Large attachment, should be saved to disk
        (1000, 100, "<test@example.com>"),
        # Message ID with special characters that need escaping
        (1000, 100, "<test/with:special\\chars>"),
    ],
)
def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
    """Attachments above the inline limit are written to disk.

    ``settings.FILE_STORAGE_DIR`` is redirected to a throwaway directory so
    the test does not leave files in the configured storage location. Besides
    checking that a ``path`` is returned and contains the sanitised message
    ID, the test verifies what its comments promise: the returned ``content``
    is unchanged and the file on disk holds exactly those bytes.
    """
    # Local import keeps this test self-contained; the module-level imports
    # do not need to change.
    import tempfile

    attachment = {
        "filename": "test.txt",
        "content_type": "text/plain",
        "size": attachment_size,
        "content": b"a" * attachment_size,
    }

    with tempfile.TemporaryDirectory() as storage_root:
        with (
            patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", max_inline_size),
            patch.object(settings, "FILE_STORAGE_DIR", pathlib.Path(storage_root)),
        ):
            result = process_attachment(attachment, message_id)

        assert result is not None

        # For disk-stored attachments a path must be set
        assert "path" in result
        assert isinstance(result["path"], pathlib.Path)

        # Verify the path contains the safe message ID
        safe_message_id = re.sub(r"[<>\s:/\\]", "_", message_id)
        assert safe_message_id in str(result["path"])

        # Content must be passed through unmodified (not base64-encoded) ...
        assert result["content"] == attachment["content"]
        # ... and the same bytes must have been written to the file.
        # Checked inside the TemporaryDirectory block, before cleanup.
        assert result["path"].read_bytes() == attachment["content"]
def test_process_attachment_write_error():
    """A failure while writing the attachment file yields ``None``.

    The inline limit is patched down to 10 so the 100-byte attachment takes
    the on-disk branch, and ``pathlib.Path.write_bytes`` is replaced with a
    function that always raises.
    """

    def failing_write(self, content):
        # Stand-in for Path.write_bytes that simulates a disk error.
        raise IOError("Test write error")

    attachment = dict(
        filename="test_error.txt",
        content_type="text/plain",
        size=100,
        content=b"a" * 100,
    )

    size_patch = patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 10)
    write_patch = patch.object(pathlib.Path, "write_bytes", failing_write)
    with size_patch, write_patch:
        outcome = process_attachment(attachment, "<test@example.com>")

    assert outcome is None
def test_process_attachments_empty():
    """Processing an empty attachment list returns an empty list."""
    no_attachments = []
    processed = process_attachments(no_attachments, "<test@example.com>")
    assert processed == []
def test_process_attachments_mixed():
    """A batch mixing small and large attachments is handled per item.

    With the inline limit patched to 50 bytes, the 20- and 30-byte
    attachments must come back with string content, and the 100-byte one
    must come back with a ``path``. Results are checked by position.
    """

    def make_attachment(filename, fill, size):
        # Build a plain-text attachment whose content is `size` copies of `fill`.
        return {
            "filename": filename,
            "content_type": "text/plain",
            "size": size,
            "content": fill * size,
        }

    attachments = [
        make_attachment("small.txt", b"a", 20),  # below the limit -> inline
        make_attachment("large.txt", b"b", 100),  # above the limit -> disk
        make_attachment("another_small.txt", b"c", 30),  # below the limit -> inline
    ]

    with patch.object(settings, "MAX_INLINE_ATTACHMENT_SIZE", 50):
        results = process_attachments(attachments, "<test@example.com>")

    # Every attachment must be accounted for.
    assert len(results) == 3
    small, large, another_small = results

    # Small attachments are returned with base64 string content.
    assert isinstance(small["content"], str)
    assert isinstance(another_small["content"], str)

    # The large attachment is referenced by a path.
    assert "path" in large
@pytest.mark.parametrize(
"msg_id, subject, sender, body, expected",
[
(
"<test@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xf2\xbd" # First two bytes of the actual hash
"<test@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xf2\xbd", # First two bytes of the actual hash
),
(
"<different@example.com>",
"Test Subject",
"sender@example.com",
"<different@example.com>",
"Test Subject",
"sender@example.com",
"Test body",
b"\xa4\x15" # Will be different from the first hash
b"\xa4\x15", # Will be different from the first hash
),
]
],
)
def test_compute_message_hash(msg_id, subject, sender, body, expected):
result = compute_message_hash(msg_id, subject, sender, body)
# Verify it's bytes and correct length for SHA-256 (32 bytes)
assert isinstance(result, bytes)
assert len(result) == 32
# Verify first two bytes match expected
assert result[:2] == expected
@ -256,11 +376,11 @@ def test_parse_simple_email():
to_addrs="recipient@example.com",
date=test_date,
body="Test body content",
message_id=msg_id
message_id=msg_id,
)
result = parse_email_message(msg.as_string())
assert result == {
"message_id": msg_id,
"subject": "Test Subject",
@ -274,47 +394,45 @@ def test_parse_simple_email():
def test_parse_email_with_attachments():
attachments = [
{"filename": "test.txt", "content": b"attachment content"}
]
attachments = [{"filename": "test.txt", "content": b"attachment content"}]
msg = create_email_message(attachments=attachments)
result = parse_email_message(msg.as_string())
assert len(result["attachments"]) == 1
assert result["attachments"][0]["filename"] == "test.txt"
def test_extract_email_uid_valid():
msg_data = [(b'1 (UID 12345 RFC822 {1234}', b'raw email content')]
msg_data = [(b"1 (UID 12345 RFC822 {1234}", b"raw email content")]
uid, raw_email = extract_email_uid(msg_data)
assert uid == "12345"
assert raw_email == b'raw email content'
assert raw_email == b"raw email content"
def test_extract_email_uid_no_match():
msg_data = [(b'1 (RFC822 {1234}', b'raw email content')]
msg_data = [(b"1 (RFC822 {1234}", b"raw email content")]
uid, raw_email = extract_email_uid(msg_data)
assert uid is None
assert raw_email == b'raw email content'
assert raw_email == b"raw email content"
def test_create_source_item(db_session):
# Mock data
message_hash = b'test_hash_bytes' + bytes(28) # 32 bytes for SHA-256
message_hash = b"test_hash_bytes" + bytes(28) # 32 bytes for SHA-256
account_tags = ["work", "important"]
raw_email_size = 1024
# Call function
source_item = create_source_item(
db_session=db_session,
message_hash=message_hash,
account_tags=account_tags,
raw_email_size=raw_email_size
raw_email_size=raw_email_size,
)
# Verify the source item was created correctly
assert isinstance(source_item, SourceItem)
assert source_item.id is not None
@ -324,7 +442,7 @@ def test_create_source_item(db_session):
assert source_item.byte_length == raw_email_size
assert source_item.mime_type == "message/rfc822"
assert source_item.embed_status == "RAW"
# Verify it was added to the session
db_session.flush()
fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
@ -339,66 +457,64 @@ def test_create_source_item(db_session):
(
lambda db: (
# First create source_item to satisfy foreign key constraint
db.add(SourceItem(
id=1,
modality="mail",
sha256=b'some_hash_bytes' + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW"
)),
db.add(
SourceItem(
id=1,
modality="mail",
sha256=b"some_hash_bytes" + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
),
db.flush(),
# Then create mail_message
db.add(MailMessage(
source_id=1,
message_id="<test@example.com>",
subject="Test",
sender="test@example.com",
recipients=["recipient@example.com"],
body_raw="Test body"
))
db.add(
MailMessage(
source_id=1,
message_id="<test@example.com>",
subject="Test",
sender="test@example.com",
recipients=["recipient@example.com"],
body_raw="Test body",
)
),
),
"<test@example.com>",
b"unmatched_hash",
True
True,
),
# Test by non-existent message ID
(
lambda db: None,
"<nonexistent@example.com>",
b"unmatched_hash",
False
),
(lambda db: None, "<nonexistent@example.com>", b"unmatched_hash", False),
# Test by hash
(
lambda db: db.add(SourceItem(
modality="mail",
sha256=b'test_hash_bytes' + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW"
)),
lambda db: db.add(
SourceItem(
modality="mail",
sha256=b"test_hash_bytes" + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
),
"",
b'test_hash_bytes' + bytes(28),
True
b"test_hash_bytes" + bytes(28),
True,
),
# Test by non-existent hash
(
lambda db: None,
"",
b'different_hash_' + bytes(28),
False
),
]
(lambda db: None, "", b"different_hash_" + bytes(28), False),
],
)
def test_check_message_exists(db_session, setup_db, message_id, message_hash, expected_exists):
def test_check_message_exists(
db_session, setup_db, message_id, message_hash, expected_exists
):
# Setup test data
if setup_db:
setup_db(db_session)
db_session.flush()
# Test the function
assert check_message_exists(db_session, message_id, message_hash) == expected_exists
@ -412,18 +528,20 @@ def test_create_mail_message(db_session):
"recipients": ["recipient@example.com"],
"sent_at": datetime(2023, 1, 1, 12, 0, 0),
"body": "Test body content",
"attachments": [{"filename": "test.txt", "content_type": "text/plain", "size": 100}]
"attachments": [
{"filename": "test.txt", "content_type": "text/plain", "size": 100}
],
}
folder = "INBOX"
# Call function
mail_message = create_mail_message(
db_session=db_session,
source_id=source_id,
parsed_email=parsed_email,
folder=folder
folder=folder,
)
# Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage)
assert mail_message.source_id == source_id
@ -433,22 +551,25 @@ def test_create_mail_message(db_session):
assert mail_message.recipients == parsed_email["recipients"]
assert mail_message.sent_at == parsed_email["sent_at"]
assert mail_message.body_raw == parsed_email["body"]
assert mail_message.attachments == {"items": parsed_email["attachments"], "folder": folder}
assert mail_message.attachments == {
"items": parsed_email["attachments"],
"folder": folder,
}
def test_fetch_email(email_provider):
    """``fetch_email`` returns ``(uid, content)`` for a known UID, else ``None``."""
    # Select the INBOX folder on the provider before fetching.
    email_provider.select("INBOX")

    # Known UID: a (uid, raw content) pair is expected.
    fetched = fetch_email(email_provider, "101")
    assert fetched is not None
    uid, content = fetched
    assert uid == "101"
    assert b"This is test email 1" in content

    # Unknown UID: nothing is returned.
    missing = fetch_email(email_provider, "999")
    assert missing is None
@ -457,63 +578,56 @@ def test_fetch_email(email_provider):
def test_fetch_email_since(email_provider):
# Fetch emails from INBOX folder
result = fetch_email_since(email_provider, "INBOX", datetime(1970, 1, 1))
# Verify we got the expected number of emails
assert len(result) == 2
# Verify content of fetched emails
uids = sorted([uid for uid, _ in result])
assert uids == ["101", "102"]
# Test with a folder that doesn't exist
result = fetch_email_since(email_provider, "NonExistentFolder", datetime(1970, 1, 1))
result = fetch_email_since(
email_provider, "NonExistentFolder", datetime(1970, 1, 1)
)
assert result == []
@patch('memory.workers.tasks.email.process_message.delay')
def test_process_folder(mock_process_message_delay, email_provider):
def test_process_folder(email_provider):
account = MagicMock(spec=EmailAccount)
account.id = 123
account.tags = ["test"]
results = process_folder(email_provider, "INBOX", account, datetime(1970, 1, 1), mock_process_message_delay)
assert results == {
"messages_found": 2,
"new_messages": 2,
"errors": 0
}
@patch('memory.workers.tasks.email.process_message.delay')
def test_process_folder_no_emails(mock_process_message_delay, email_provider):
results = process_folder(
email_provider, "INBOX", account, datetime(1970, 1, 1), MagicMock()
)
assert results == {"messages_found": 2, "new_messages": 2, "errors": 0}
def test_process_folder_no_emails(email_provider):
account = MagicMock(spec=EmailAccount)
account.id = 123
email_provider.search = MagicMock(return_value=("OK", [b'']))
result = process_folder(email_provider, "Empty", account, datetime(1970, 1, 1), mock_process_message_delay)
assert result == {
"messages_found": 0,
"new_messages": 0,
"errors": 0
}
email_provider.search = MagicMock(return_value=("OK", [b""]))
result = process_folder(
email_provider, "Empty", account, datetime(1970, 1, 1), MagicMock()
)
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
def test_process_folder_error(email_provider):
account = MagicMock(spec=EmailAccount)
account.id = 123
mock_processor = MagicMock()
def raise_exception(*args):
raise Exception("Test error")
email_provider.search = raise_exception
result = process_folder(email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor)
assert result == {
"messages_found": 0,
"new_messages": 0,
"errors": 0
}
email_provider.search = raise_exception
result = process_folder(
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
)
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}