diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index df858ee..3804356 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -6,7 +6,7 @@ import re from contextlib import contextmanager from datetime import datetime from email.utils import parsedate_to_datetime - +from typing import Generator from sqlalchemy.orm import Session from memory.common.db.models import EmailAccount, MailMessage, SourceItem @@ -227,12 +227,15 @@ def check_message_exists(db_session: Session, message_id: str, message_hash: byt Returns: True if message exists, False otherwise """ - return ( - # Check by message_id first (faster) - message_id and db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first() - # Then check by message_hash - or db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() is not None - ) + # Check by message_id first (faster) + if message_id: + mail_message = db_session.query(MailMessage).filter(MailMessage.message_id == message_id).first() + if mail_message is not None: + return True + + # Then check by message_hash + source_item = db_session.query(SourceItem).filter(SourceItem.sha256 == message_hash).first() + return source_item is not None def extract_email_uid(msg_data: bytes) -> tuple[str, str]: @@ -316,12 +319,16 @@ def process_folder( Stats dictionary for the folder """ new_messages, errors = 0, 0 + emails = [] # Initialize to avoid UnboundLocalError try: emails = fetch_email_since(conn, folder, since_date) for uid, raw_email in emails: try: + # Import process_message here to avoid circular imports + from memory.workers.tasks.email import process_message + task = process_message.delay( account_id=account.id, message_id=uid, @@ -346,7 +353,7 @@ def process_folder( @contextmanager -def imap_connection(account: EmailAccount) -> imaplib.IMAP4_SSL: +def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, None]: conn = imaplib.IMAP4_SSL( host=account.imap_server, port=account.imap_port diff --git a/tests/memory/workers/tasks/conftest.py b/tests/memory/workers/tasks/conftest.py index feaed21..e5f0270 100644 --- a/tests/memory/workers/tasks/conftest.py +++ b/tests/memory/workers/tasks/conftest.py @@ -56,7 +56,6 @@ def run_alembic_migrations(db_name: str) -> None: project_root = Path(__file__).parent.parent.parent.parent.parent alembic_ini = project_root / "db" / "migrations" / "alembic.ini" - breakpoint() subprocess.run( ["alembic", "-c", str(alembic_ini), "upgrade", "head"], env={**os.environ, "DATABASE_URL": settings.make_db_url(db=db_name)}, diff --git a/tests/memory/workers/tasks/test_email.py b/tests/memory/workers/tasks/test_email.py index 0f52083..e19971b 100644 --- a/tests/memory/workers/tasks/test_email.py +++ b/tests/memory/workers/tasks/test_email.py @@ -4,10 +4,11 @@ import email.mime.text import email.mime.base from datetime import datetime from email.utils import formatdate -from unittest.mock import ANY +from unittest.mock import ANY, MagicMock, patch import pytest - +import imaplib from memory.common.db.models import SourceItem +from memory.common.db.models import MailMessage, EmailAccount from memory.workers.email import ( compute_message_hash, create_source_item, @@ -17,6 +18,12 @@ from memory.workers.email import ( extract_email_uid, extract_recipients, parse_email_message, + check_message_exists, + create_mail_message, + fetch_email, + fetch_email_since, + process_folder, + imap_connection, ) @@ -323,3 +330,275 @@ def test_create_source_item(db_session): fetched_item = db_session.query(SourceItem).filter_by(id=source_item.id).one() assert fetched_item is not None assert fetched_item.sha256 == message_hash + + +@pytest.mark.parametrize( + "setup_db, message_id, message_hash, expected_exists", + [ + # Test by message ID + ( + lambda db: ( + # First create source_item to satisfy foreign key constraint + db.add(SourceItem( + id=1, + modality="mail", + sha256=b'some_hash_bytes' + bytes(28), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW" + )), + db.flush(), + # Then create mail_message + db.add(MailMessage( + source_id=1, + message_id="", + subject="Test", + sender="test@example.com", + recipients=["recipient@example.com"], + body_raw="Test body" + )) + ), + "", + b"unmatched_hash", + True + ), + # Test by non-existent message ID + ( + lambda db: None, + "", + b"unmatched_hash", + False + ), + # Test by hash + ( + lambda db: db.add(SourceItem( + modality="mail", + sha256=b'test_hash_bytes' + bytes(28), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW" + )), + "", + b'test_hash_bytes' + bytes(28), + True + ), + # Test by non-existent hash + ( + lambda db: None, + "", + b'different_hash_' + bytes(28), + False + ), + ] +) +def test_check_message_exists(db_session, setup_db, message_id, message_hash, expected_exists): + # Setup test data + if setup_db: + setup_db(db_session) + db_session.flush() + + # Test the function + assert check_message_exists(db_session, message_id, message_hash) == expected_exists + + +def test_create_mail_message(db_session): + source_id = 1 + parsed_email = { + "message_id": "", + "subject": "Test Subject", + "sender": "sender@example.com", + "recipients": ["recipient@example.com"], + "sent_at": datetime(2023, 1, 1, 12, 0, 0), + "body": "Test body content", + "attachments": [{"filename": "test.txt", "content_type": "text/plain", "size": 100}] + } + folder = "INBOX" + + # Call function + mail_message = create_mail_message( + db_session=db_session, + source_id=source_id, + parsed_email=parsed_email, + folder=folder + ) + + # Verify the mail message was created correctly + assert isinstance(mail_message, MailMessage) + assert mail_message.source_id == source_id + assert mail_message.message_id == parsed_email["message_id"] + assert mail_message.subject == parsed_email["subject"] + assert mail_message.sender == parsed_email["sender"] + assert mail_message.recipients == parsed_email["recipients"] + assert mail_message.sent_at == parsed_email["sent_at"] + assert mail_message.body_raw == parsed_email["body"] + assert mail_message.attachments == {"items": parsed_email["attachments"], "folder": folder} + + +@pytest.mark.parametrize( + "fetch_return, fetch_side_effect, extract_uid_return, expected_result", + [ + # Success case + (('OK', ['mock_data']), None, ("12345", b'raw email content'), ("12345", b'raw email content')), + # IMAP error + (('NO', []), None, None, None), + # Exception case + (None, Exception("Test error"), None, None), + ] +) +@patch('memory.workers.email.extract_email_uid') +def test_fetch_email( + mock_extract_email_uid, fetch_return, fetch_side_effect, extract_uid_return, expected_result +): + conn = MagicMock(spec=imaplib.IMAP4_SSL) + + # Configure mocks + if fetch_side_effect: + conn.fetch.side_effect = fetch_side_effect + else: + conn.fetch.return_value = fetch_return + + if extract_uid_return: + mock_extract_email_uid.return_value = extract_uid_return + + uid = "12345" + + # Call function + result = fetch_email(conn, uid) + + # Verify expectations + assert result == expected_result + + # Verify fetch was called if no exception + if not fetch_side_effect: + conn.fetch.assert_called_once_with(uid, '(UID RFC822)') + + +@pytest.mark.parametrize( + "select_return, search_return, select_side_effect, expected_calls, expected_result", + [ + # Successful case with multiple messages + ( + ('OK', [b'1']), + ('OK', [b'1 2 3']), + None, + 3, + [("1", b'email1'), ("2", b'email2'), ("3", b'email3')] + ), + # No messages found case + ( + ('OK', [b'0']), + ('OK', [b'']), + None, + 0, + [] + ), + # Error in select + ( + ('NO', [b'Error']), + None, + None, + 0, + [] + ), + # Error in search + ( + ('OK', [b'1']), + ('NO', [b'Error']), + None, + 0, + [] + ), + # Exception in select + ( + None, + None, + Exception("Test error"), + 0, + [] + ), + ] +) +@patch('memory.workers.email.fetch_email') +def test_fetch_email_since( + mock_fetch_email, select_return, search_return, select_side_effect, expected_calls, expected_result +): + conn = MagicMock(spec=imaplib.IMAP4_SSL) + + # Configure mocks based on parameters + if select_side_effect: + conn.select.side_effect = select_side_effect + else: + conn.select.return_value = select_return + + if search_return: + conn.search.return_value = search_return + + # Configure fetch_email mock if needed + if expected_calls > 0: + mock_fetch_email.side_effect = [ + (f"{i+1}", f"email{i+1}".encode()) for i in range(expected_calls) + ] + + folder = "INBOX" + since_date = datetime(2023, 1, 1) + + result = fetch_email_since(conn, folder, since_date) + + assert mock_fetch_email.call_count == expected_calls + assert result == expected_result + + +@patch('memory.workers.email.fetch_email_since') +def test_process_folder_error(mock_fetch_email_since): + # Setup + conn = MagicMock(spec=imaplib.IMAP4_SSL) + folder = "INBOX" + account = MagicMock(spec=EmailAccount) + since_date = datetime(2023, 1, 1) + + # Test exception in fetch_email_since + mock_fetch_email_since.side_effect = Exception("Test error") + + # Call function + result = process_folder(conn, folder, account, since_date) + + # Verify + assert result["messages_found"] == 0 + assert result["new_messages"] == 0 + assert result["errors"] == 1 + + +@patch('memory.workers.tasks.email.process_message.delay') +@patch('memory.workers.email.fetch_email_since') +def test_process_folder(mock_fetch_email_since, mock_process_message_delay): + conn = MagicMock(spec=imaplib.IMAP4_SSL) + folder = "INBOX" + account = MagicMock(spec=EmailAccount) + account.id = 123 + since_date = datetime(2023, 1, 1) + + mock_fetch_email_since.return_value = [ + ("1", b'email1'), + ("2", b'email2'), + ] + + mock_process_message_delay.return_value = MagicMock() + + with patch('builtins.__import__', side_effect=__import__): + result = process_folder(conn, folder, account, since_date) + + mock_fetch_email_since.assert_called_once_with(conn, folder, since_date) + assert mock_process_message_delay.call_count == 2 + + mock_process_message_delay.assert_any_call( + account_id=account.id, + message_id="1", + folder=folder, + raw_email='email1' + ) + + assert result["messages_found"] == 2 + assert result["new_messages"] == 2 + assert result["errors"] == 0