From 869e5ac6b46edc3e45157eb4b02acac01e4db890 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 28 Apr 2025 15:47:26 +0200 Subject: [PATCH] embeding tests --- .../20250427_171537_initial_structure.py | 4 + requirements-common.txt | 4 +- requirements-workers.txt | 1 - src/memory/common/db/models.py | 4 +- src/memory/common/embedding.py | 16 ++-- src/memory/workers/email.py | 46 ++++++---- src/memory/workers/tasks/email.py | 3 +- tests/memory/common/test_embedding.py | 2 +- .../memory/workers/tasks/test_email_tasks.py | 10 ++- tests/memory/workers/test_email.py | 84 ++++++++++--------- 10 files changed, 101 insertions(+), 73 deletions(-) diff --git a/db/migrations/versions/20250427_171537_initial_structure.py b/db/migrations/versions/20250427_171537_initial_structure.py index a355857..d1a7a43 100644 --- a/db/migrations/versions/20250427_171537_initial_structure.py +++ b/db/migrations/versions/20250427_171537_initial_structure.py @@ -172,6 +172,7 @@ def upgrade() -> None: op.create_table( "email_attachment", sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.BigInteger(), nullable=False), sa.Column("mail_message_id", sa.BigInteger(), nullable=False), sa.Column("filename", sa.Text(), nullable=False), sa.Column("content_type", sa.Text(), nullable=True), @@ -187,6 +188,9 @@ def upgrade() -> None: sa.ForeignKeyConstraint( ["mail_message_id"], ["mail_message.id"], ondelete="CASCADE" ), + sa.ForeignKeyConstraint( + ["source_id"], ["source_item.id"], ondelete="CASCADE" + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( diff --git a/requirements-common.txt b/requirements-common.txt index 7f78901..92052c5 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -2,4 +2,6 @@ sqlalchemy==2.0.30 psycopg2-binary==2.9.9 pydantic==2.7.1 alembic==1.13.1 -dotenv==1.1.0 \ No newline at end of file +dotenv==1.1.0 +voyageai==0.3.2 +qdrant-client==1.9.0 \ No newline at end of file diff --git a/requirements-workers.txt b/requirements-workers.txt index 9423d63..ae0428c 100644 --- a/requirements-workers.txt +++ b/requirements-workers.txt @@ -1,4 +1,3 @@ celery==5.3.6 openai==1.25.0 pillow==10.3.0 -qdrant-client==1.9.0 \ No newline at end of file diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 490c0bf..74e0686 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -30,7 +30,7 @@ class SourceItem(Base): mime_type = Column(Text) mail_message = relationship("MailMessage", back_populates="source", uselist=False) - + attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False) # Add table-level constraint and indexes __table_args__ = ( CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), @@ -85,6 +85,7 @@ class EmailAttachment(Base): __tablename__ = 'email_attachment' id = Column(BigInteger, primary_key=True) + source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False) mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False) filename = Column(Text, nullable=False) content_type = Column(Text) @@ -94,6 +95,7 @@ class EmailAttachment(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) mail_message = relationship("MailMessage", back_populates="attachments") + source = relationship("SourceItem", back_populates="attachments") def as_payload(self) -> dict: return { diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index 6d75048..bedc110 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -36,7 +36,7 @@ TYPES = { } -def get_type(mime_type: str) -> str: +def get_modality(mime_type: str) -> str: for type, mime_types in TYPES.items(): if mime_type in mime_types: return type @@ -109,7 +109,7 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]: yield chunk -def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> list[str]: +def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]: """ Split text into chunks respecting semantic boundaries while staying within token limits. @@ -149,10 +149,6 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T overlap_text.rfind("? ") ) - print(f"clean_break: {clean_break}") - print(f"overlap_text: {overlap_text}") - print(f"current: {current}") - if clean_break < 0: yield current current = "" @@ -160,18 +156,16 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T break_offset = -overlap_chars + clean_break + 1 chunk = current[break_offset:].strip() - print(f"chunk: {chunk}") - print(f"current: {current}") yield current current = chunk + if current: yield current.strip() def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: vo = voyageai.Client() - chunks = chunk_text(text, MAX_TOKENS) - return [vo.embed(chunk, model=model) for chunk in chunks] + return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model) def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: @@ -182,4 +176,4 @@ def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n if isinstance(content, bytes): content = content.decode("utf-8") - return get_type(mime_type), embed_text(content, model, n_dimensions) + return get_modality(mime_type), embed_text(content, model, n_dimensions) diff --git a/src/memory/workers/email.py b/src/memory/workers/email.py index e054aad..faae8d3 100644 --- a/src/memory/workers/email.py +++ b/src/memory/workers/email.py @@ -168,9 +168,18 @@ def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAtt logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}") return None + source_item = SourceItem( + modality=embedding.get_modality(attachment["content_type"]), + sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(), + tags=message.source.tags, + byte_length=attachment["size"], + mime_type=attachment["content_type"], + ) + return EmailAttachment( - mail_message_id=message.id, + source=source_item, filename=attachment["filename"], + mail_message=message, content_type=attachment.get("content_type"), size=attachment.get("size"), content=content, @@ -242,7 +251,9 @@ def create_source_item( db_session: Session, message_hash: bytes, account_tags: list[str], - raw_email_size: int, + raw_size: int, + modality: str = "mail", + mime_type: str = "message/rfc822", ) -> SourceItem: """ Create a new source item record. @@ -251,17 +262,17 @@ def create_source_item( db_session: Database session message_hash: SHA-256 hash of message account_tags: Tags from the email account - raw_email_size: Size of raw email in bytes + raw_size: Size of raw email in bytes Returns: Newly created SourceItem """ source_item = SourceItem( - modality="mail", + modality=modality, sha256=message_hash, tags=account_tags, - byte_length=raw_email_size, - mime_type="message/rfc822", + byte_length=raw_size, + mime_type=mime_type, embed_status="RAW" ) db_session.add(source_item) @@ -271,7 +282,7 @@ def create_source_item( def create_mail_message( db_session: Session, - source_id: int, + source_item: SourceItem, parsed_email: EmailMessage, folder: str, ) -> MailMessage: @@ -288,7 +299,7 @@ def create_mail_message( Newly created MailMessage """ mail_message = MailMessage( - source_id=source_id, + source=source_item, message_id=parsed_email["message_id"], subject=parsed_email["subject"], sender=parsed_email["sender"], @@ -462,15 +473,17 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None, def vectorize_email(email: MailMessage) -> list[float]: qdrant_client = get_qdrant_client() - vector_id = uuid.uuid4() + chunks = embedding.embed_text(email.body_raw) + payloads = [email.as_payload()] * len(chunks) + vector_ids = [str(uuid.uuid4()) for _ in chunks] upsert_vectors( client=qdrant_client, collection_name="mail", - ids=[str(vector_id)], - vectors=[embedding.embed_text(email.body_raw)], - payloads=[email.as_payload()], + ids=vector_ids, + vectors=chunks, + payloads=payloads, ) - vector_ids = [f"mail/{vector_id}"] + vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids] embeds = defaultdict(list) for attachment in email.attachments: @@ -478,8 +491,11 @@ def vectorize_email(email: MailMessage) -> list[float]: content = pathlib.Path(attachment.file_path).read_bytes() else: content = attachment.content - collection, vector = embedding.embed(attachment.content_type, content) - embeds[collection].append((str(uuid.uuid4()), vector, attachment.as_payload())) + collection, vectors = embedding.embed(attachment.content_type, content) + attachment.source.vector_ids = vector_ids + embeds[collection].extend( + (str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors + ) for collection, embeds in embeds.items(): ids, vectors, payloads = zip(*embeds) diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index a37ad16..d6a9995 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -63,8 +63,7 @@ def process_message( return None source_item = create_source_item(db, message_hash, account.tags, len(raw_email)) - - mail_message = create_mail_message(db, source_item.id, parsed_email, folder) + mail_message = create_mail_message(db, source_item, parsed_email, folder) source_item.vector_ids = vectorize_email(mail_message) db.commit() diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py index 69c05da..fa159d3 100644 --- a/tests/memory/common/test_embedding.py +++ b/tests/memory/common/test_embedding.py @@ -1,5 +1,5 @@ import pytest -from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count +from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN @pytest.mark.parametrize( diff --git a/tests/memory/workers/tasks/test_email_tasks.py b/tests/memory/workers/tasks/test_email_tasks.py index acd5abd..665ea9e 100644 --- a/tests/memory/workers/tasks/test_email_tasks.py +++ b/tests/memory/workers/tasks/test_email_tasks.py @@ -1,7 +1,9 @@ +from unittest import mock import pytest from datetime import datetime, timedelta - +from unittest.mock import patch from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment +from memory.common import embedding from memory.workers.tasks.email import process_message @@ -36,6 +38,12 @@ VGhpcyBpcyBhIHRlc3QgYXR0YWNobWVudA== --boundary123--""" +@pytest.fixture(autouse=True) +def mock_voyage_embed_text(): + with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]): + yield + + @pytest.fixture def test_email_account(db_session): """Create a test email account for integration testing.""" diff --git a/tests/memory/workers/test_email.py b/tests/memory/workers/test_email.py index 3aec3fc..85e0063 100644 --- a/tests/memory/workers/test_email.py +++ b/tests/memory/workers/test_email.py @@ -11,6 +11,7 @@ from unittest.mock import ANY, MagicMock, patch import pytest from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount from memory.common import settings +from memory.common import embedding from memory.workers.email import ( compute_message_hash, create_source_item, @@ -236,6 +237,7 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id) } message = MailMessage( id=1, + source=SourceItem(tags=["test"]), message_id=message_id, sender="sender@example.com", folder="INBOX", @@ -271,6 +273,7 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id): } message = MailMessage( id=1, + source=SourceItem(tags=["test"]), message_id=message_id, sender="sender@example.com", folder="INBOX", @@ -293,6 +296,7 @@ def test_process_attachment_write_error(): } message = MailMessage( id=1, + source=SourceItem(tags=["test"]), message_id="", sender="sender@example.com", folder="INBOX", @@ -340,6 +344,7 @@ def test_process_attachments_mixed(): ] message = MailMessage( id=1, + source=SourceItem(tags=["test"]), message_id="", sender="sender@example.com", folder="INBOX", @@ -458,7 +463,7 @@ def test_create_source_item(db_session): db_session=db_session, message_hash=message_hash, account_tags=account_tags, - raw_email_size=raw_email_size, + raw_size=raw_email_size, ) # Verify the source item was created correctly @@ -549,7 +554,6 @@ def test_check_message_exists( def test_create_mail_message(db_session): source_item = SourceItem( - id=1, modality="mail", sha256=b"test_hash_bytes" + bytes(28), tags=["test"], @@ -557,7 +561,6 @@ def test_create_mail_message(db_session): ) db_session.add(source_item) db_session.flush() - source_id = source_item.id parsed_email = { "message_id": "", "subject": "Test Subject", @@ -574,21 +577,22 @@ def test_create_mail_message(db_session): # Call function mail_message = create_mail_message( db_session=db_session, - source_id=source_id, + source_item=source_item, parsed_email=parsed_email, folder=folder, ) + db_session.commit() attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all() # Verify the mail message was created correctly assert isinstance(mail_message, MailMessage) - assert mail_message.source_id == source_id + assert mail_message.source_id == source_item.id assert mail_message.message_id == parsed_email["message_id"] assert mail_message.subject == parsed_email["subject"] assert mail_message.sender == parsed_email["sender"] assert mail_message.recipients == parsed_email["recipients"] - assert mail_message.sent_at == parsed_email["sent_at"] + assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat() assert mail_message.body_raw == parsed_email["body"] assert mail_message.attachments == attachments @@ -670,22 +674,15 @@ def test_process_folder_error(email_provider): def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): - source_item = SourceItem( - id=1, - modality="mail", - sha256=b"test_hash" + bytes(24), - tags=["test"], - byte_length=100, - mime_type="message/rfc822", - embed_status="RAW", - ) - db_session.add(source_item) - db_session.flush() - - # Create mail message mail_message = MailMessage( - id=1, - source_id=1, + source=SourceItem( + modality="mail", + sha256=b"test_hash" + bytes(24), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW", + ), message_id="", subject="Test Vectorization", sender="sender@example.com", @@ -696,7 +693,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): db_session.add(mail_message) db_session.flush() - with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536): + with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]): vector_ids = vectorize_email(mail_message) assert len(vector_ids) == 1 @@ -704,22 +701,15 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): - source_item = SourceItem( - id=2, - modality="mail", - sha256=b"test_hash2" + bytes(24), - tags=["test"], - byte_length=200, - mime_type="message/rfc822", - embed_status="RAW", - ) - db_session.add(source_item) - db_session.flush() - - # Create mail message mail_message = MailMessage( - id=2, - source_id=2, + source=SourceItem( + modality="mail", + sha256=b"test_hash" + bytes(24), + tags=["test"], + byte_length=100, + mime_type="message/rfc822", + embed_status="RAW", + ), message_id="", subject="Test Vectorization with Attachments", sender="sender@example.com", @@ -732,33 +722,47 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): # Add two attachments - one with content and one with file_path attachment1 = EmailAttachment( - id=1, mail_message_id=mail_message.id, filename="inline.txt", content_type="text/plain", size=100, content=base64.b64encode(b"This is inline content"), file_path=None, + source=SourceItem( + modality="doc", + sha256=b"test_hash1" + bytes(24), + tags=["test"], + byte_length=100, + mime_type="text/plain", + embed_status="RAW", + ), ) file_path = mail_message.attachments_path / "stored.txt" file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_bytes(b"This is stored content") attachment2 = EmailAttachment( - id=2, mail_message_id=mail_message.id, filename="stored.txt", content_type="text/plain", size=200, content=None, file_path=str(file_path), + source=SourceItem( + modality="doc", + sha256=b"test_hash2" + bytes(24), + tags=["test"], + byte_length=100, + mime_type="text/plain", + embed_status="RAW", + ), ) db_session.add_all([attachment1, attachment2]) db_session.flush() # Mock embedding functions but use real qdrant - with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536): + with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]): # Call the function vector_ids = vectorize_email(mail_message)