tests for email

This commit is contained in:
Daniel O'Connell 2025-04-27 17:51:25 +02:00
parent d1cac9ffd9
commit d3117a4e6a
3 changed files with 296 additions and 11 deletions

View File

@ -6,7 +6,7 @@ import re
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from typing import Generator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem from memory.common.db.models import EmailAccount, MailMessage, SourceItem
@ -227,12 +227,15 @@ def check_message_exists(db_session: Session, message_id: str, message_hash: byt
Returns: Returns:
True if message exists, False otherwise True if message exists, False otherwise
""" """
return ( # Check by message_id first (faster)
# Check by message_id first (faster) if message_id:
message_id and db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first() mail_message = db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first()
# Then check by message_hash if mail_message is not None:
or db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() is not None return True
)
# Then check by message_hash
source_item = db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first()
return source_item is not None
def extract_email_uid(msg_data: bytes) -> tuple[str, str]: def extract_email_uid(msg_data: bytes) -> tuple[str, str]:
@ -316,12 +319,16 @@ def process_folder(
Stats dictionary for the folder Stats dictionary for the folder
""" """
new_messages, errors = 0, 0 new_messages, errors = 0, 0
emails = [] # Initialize to avoid UnboundLocalError
try: try:
emails = fetch_email_since(conn, folder, since_date) emails = fetch_email_since(conn, folder, since_date)
for uid, raw_email in emails: for uid, raw_email in emails:
try: try:
# Import process_message here to avoid circular imports
from memory.workers.tasks.email import process_message
task = process_message.delay( task = process_message.delay(
account_id=account.id, account_id=account.id,
message_id=uid, message_id=uid,
@ -346,7 +353,7 @@ def process_folder(
@contextmanager @contextmanager
def imap_connection(account: EmailAccount) -> imaplib.IMAP4_SSL: def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]:
conn = imaplib.IMAP4_SSL( conn = imaplib.IMAP4_SSL(
host=account.imap_server, host=account.imap_server,
port=account.imap_port port=account.imap_port

View File

@ -56,7 +56,6 @@ def run_alembic_migrations(db_name: str) -> None:
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
alembic_ini = project_root / "db" / "migrations" / "alembic.ini" alembic_ini = project_root / "db" / "migrations" / "alembic.ini"
breakpoint()
subprocess.run( subprocess.run(
["alembic", "-c", str(alembic_ini), "upgrade", "head"], ["alembic", "-c", str(alembic_ini), "upgrade", "head"],
env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)}, env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)},

View File

@ -4,10 +4,11 @@ import email.mime.text
import email.mime.base import email.mime.base
from datetime import datetime from datetime import datetime
from email.utils import formatdate from email.utils import formatdate
from unittest.mock import ANY from unittest.mock import ANY, MagicMock, patch
import pytest import pytest
import imaplib
from memory.common.db.models import SourceItem from memory.common.db.models import SourceItem
from memory.common.db.models import MailMessage, EmailAccount
from memory.workers.email import ( from memory.workers.email import (
compute_message_hash, compute_message_hash,
create_source_item, create_source_item,
@ -17,6 +18,12 @@ from memory.workers.email import (
extract_email_uid, extract_email_uid,
extract_recipients, extract_recipients,
parse_email_message, parse_email_message,
check_message_exists,
create_mail_message,
fetch_email,
fetch_email_since,
process_folder,
imap_connection,
) )
@ -323,3 +330,275 @@ def test_create_source_item(db_session):
fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one() fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one()
assert fetched_item is not None assert fetched_item is not None
assert fetched_item.sha256 == message_hash assert fetched_item.sha256 == message_hash
@pytest.mark.parametrize(
"setup_db, message_id, message_hash, expected_exists",
[
# Test by message ID
(
lambda db: (
# First create source_item to satisfy foreign key constraint
db.add(SourceItem(
id=1,
modality="mail",
sha256=b'some_hash_bytes' + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW"
)),
db.flush(),
# Then create mail_message
db.add(MailMessage(
source_id=1,
message_id="<test@example.com>",
subject="Test",
sender="test@example.com",
recipients=["recipient@example.com"],
body_raw="Test body"
))
),
"<test@example.com>",
b"unmatched_hash",
True
),
# Test by non-existent message ID
(
lambda db: None,
"<nonexistent@example.com>",
b"unmatched_hash",
False
),
# Test by hash
(
lambda db: db.add(SourceItem(
modality="mail",
sha256=b'test_hash_bytes' + bytes(28),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW"
)),
"",
b'test_hash_bytes' + bytes(28),
True
),
# Test by non-existent hash
(
lambda db: None,
"",
b'different_hash_' + bytes(28),
False
),
]
)
def test_check_message_exists(db_session, setup_db, message_id, message_hash, expected_exists):
# Setup test data
if setup_db:
setup_db(db_session)
db_session.flush()
# Test the function
assert check_message_exists(db_session, message_id, message_hash) == expected_exists
def test_create_mail_message(db_session):
source_id = 1
parsed_email = {
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient@example.com"],
"sent_at": datetime(2023, 1, 1, 12, 0, 0),
"body": "Test body content",
"attachments": [{"filename": "test.txt", "content_type": "text/plain", "size": 100}]
}
folder = "INBOX"
# Call function
mail_message = create_mail_message(
db_session=db_session,
source_id=source_id,
parsed_email=parsed_email,
folder=folder
)
# Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage)
assert mail_message.source_id == source_id
assert mail_message.message_id == parsed_email["message_id"]
assert mail_message.subject == parsed_email["subject"]
assert mail_message.sender == parsed_email["sender"]
assert mail_message.recipients == parsed_email["recipients"]
assert mail_message.sent_at == parsed_email["sent_at"]
assert mail_message.body_raw == parsed_email["body"]
assert mail_message.attachments == {"items": parsed_email["attachments"], "folder": folder}
@pytest.mark.parametrize(
"fetch_return, fetch_side_effect, extract_uid_return, expected_result",
[
# Success case
(('OK', ['mock_data']), None, ("12345", b'raw email content'), ("12345", b'raw email content')),
# IMAP error
(('NO', []), None, None, None),
# Exception case
(None, Exception("Test error"), None, None),
]
)
@patch('memory.workers.email.extract_email_uid')
def test_fetch_email(
mock_extract_email_uid, fetch_return, fetch_side_effect, extract_uid_return, expected_result
):
conn = MagicMock(spec=imaplib.IMAP4_SSL)
# Configure mocks
if fetch_side_effect:
conn.fetch.side_effect = fetch_side_effect
else:
conn.fetch.return_value = fetch_return
if extract_uid_return:
mock_extract_email_uid.return_value = extract_uid_return
uid = "12345"
# Call function
result = fetch_email(conn, uid)
# Verify expectations
assert result == expected_result
# Verify fetch was called if no exception
if not fetch_side_effect:
conn.fetch.assert_called_once_with(uid, '(UID RFC822)')
@pytest.mark.parametrize(
"select_return, search_return, select_side_effect, expected_calls, expected_result",
[
# Successful case with multiple messages
(
('OK', [b'1']),
('OK', [b'1 2 3']),
None,
3,
[("1", b'email1'), ("2", b'email2'), ("3", b'email3')]
),
# No messages found case
(
('OK', [b'0']),
('OK', [b'']),
None,
0,
[]
),
# Error in select
(
('NO', [b'Error']),
None,
None,
0,
[]
),
# Error in search
(
('OK', [b'1']),
('NO', [b'Error']),
None,
0,
[]
),
# Exception in select
(
None,
None,
Exception("Test error"),
0,
[]
),
]
)
@patch('memory.workers.email.fetch_email')
def test_fetch_email_since(
mock_fetch_email, select_return, search_return, select_side_effect, expected_calls, expected_result
):
conn = MagicMock(spec=imaplib.IMAP4_SSL)
# Configure mocks based on parameters
if select_side_effect:
conn.select.side_effect = select_side_effect
else:
conn.select.return_value = select_return
if search_return:
conn.search.return_value = search_return
# Configure fetch_email mock if needed
if expected_calls > 0:
mock_fetch_email.side_effect = [
(f"{i+1}", f"email{i+1}".encode()) for i in range(expected_calls)
]
folder = "INBOX"
since_date = datetime(2023, 1, 1)
result = fetch_email_since(conn, folder, since_date)
assert mock_fetch_email.call_count == expected_calls
assert result == expected_result
@patch('memory.workers.email.fetch_email_since')
def test_process_folder_error(mock_fetch_email_since):
# Setup
conn = MagicMock(spec=imaplib.IMAP4_SSL)
folder = "INBOX"
account = MagicMock(spec=EmailAccount)
since_date = datetime(2023, 1, 1)
# Test exception in fetch_email_since
mock_fetch_email_since.side_effect = Exception("Test error")
# Call function
result = process_folder(conn, folder, account, since_date)
# Verify
assert result["messages_found"] == 0
assert result["new_messages"] == 0
assert result["errors"] == 1
@patch('memory.workers.tasks.email.process_message.delay')
@patch('memory.workers.email.fetch_email_since')
def test_process_folder(mock_fetch_email_since, mock_process_message_delay):
conn = MagicMock(spec=imaplib.IMAP4_SSL)
folder = "INBOX"
account = MagicMock(spec=EmailAccount)
account.id = 123
since_date = datetime(2023, 1, 1)
mock_fetch_email_since.return_value = [
("1", b'email1'),
("2", b'email2'),
]
mock_process_message_delay.return_value = MagicMock()
with patch('builtins.__import__', side_effect=__import__):
result = process_folder(conn, folder, account, since_date)
mock_fetch_email_since.assert_called_once_with(conn, folder, since_date)
assert mock_process_message_delay.call_count == 2
mock_process_message_delay.assert_any_call(
account_id=account.id,
message_id="1",
folder=folder,
raw_email='email1'
)
assert result["messages_found"] == 2
assert result["new_messages"] == 2
assert result["errors"] == 0