embeding tests

This commit is contained in:
Daniel O'Connell 2025-04-28 15:47:26 +02:00
parent 2d2f37536a
commit 869e5ac6b4
10 changed files with 101 additions and 73 deletions

View File

@ -172,6 +172,7 @@ def upgrade() -> None:
op.create_table( op.create_table(
"email_attachment", "email_attachment",
sa.Column("id", sa.BigInteger(), nullable=False), sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("mail_message_id", sa.BigInteger(), nullable=False), sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
sa.Column("filename", sa.Text(), nullable=False), sa.Column("filename", sa.Text(), nullable=False),
sa.Column("content_type", sa.Text(), nullable=True), sa.Column("content_type", sa.Text(), nullable=True),
@ -187,6 +188,9 @@ def upgrade() -> None:
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE" ["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
), ),
sa.ForeignKeyConstraint(
["source_id"], ["source_item.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_table( op.create_table(

View File

@ -3,3 +3,5 @@ psycopg2-binary==2.9.9
pydantic==2.7.1 pydantic==2.7.1
alembic==1.13.1 alembic==1.13.1
dotenv==1.1.0 dotenv==1.1.0
voyageai==0.3.2
qdrant-client==1.9.0

View File

@ -1,4 +1,3 @@
celery==5.3.6 celery==5.3.6
openai==1.25.0 openai==1.25.0
pillow==10.3.0 pillow==10.3.0
qdrant-client==1.9.0

View File

@ -30,7 +30,7 @@ class SourceItem(Base):
mime_type = Column(Text) mime_type = Column(Text)
mail_message = relationship("MailMessage", back_populates="source", uselist=False) mail_message = relationship("MailMessage", back_populates="source", uselist=False)
attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False)
# Add table-level constraint and indexes # Add table-level constraint and indexes
__table_args__ = ( __table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
@ -85,6 +85,7 @@ class EmailAttachment(Base):
__tablename__ = 'email_attachment' __tablename__ = 'email_attachment'
id = Column(BigInteger, primary_key=True) id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False) mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False)
filename = Column(Text, nullable=False) filename = Column(Text, nullable=False)
content_type = Column(Text) content_type = Column(Text)
@ -94,6 +95,7 @@ class EmailAttachment(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
mail_message = relationship("MailMessage", back_populates="attachments") mail_message = relationship("MailMessage", back_populates="attachments")
source = relationship("SourceItem", back_populates="attachments")
def as_payload(self) -> dict: def as_payload(self) -> dict:
return { return {

View File

@ -36,7 +36,7 @@ TYPES = {
} }
def get_type(mime_type: str) -> str: def get_modality(mime_type: str) -> str:
for type, mime_types in TYPES.items(): for type, mime_types in TYPES.items():
if mime_type in mime_types: if mime_type in mime_types:
return type return type
@ -109,7 +109,7 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
yield chunk yield chunk
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> list[str]: def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
""" """
Split text into chunks respecting semantic boundaries while staying within token limits. Split text into chunks respecting semantic boundaries while staying within token limits.
@ -149,10 +149,6 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
overlap_text.rfind("? ") overlap_text.rfind("? ")
) )
print(f"clean_break: {clean_break}")
print(f"overlap_text: {overlap_text}")
print(f"current: {current}")
if clean_break < 0: if clean_break < 0:
yield current yield current
current = "" current = ""
@ -160,18 +156,16 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
break_offset = -overlap_chars + clean_break + 1 break_offset = -overlap_chars + clean_break + 1
chunk = current[break_offset:].strip() chunk = current[break_offset:].strip()
print(f"chunk: {chunk}")
print(f"current: {current}")
yield current yield current
current = chunk current = chunk
if current: if current:
yield current.strip() yield current.strip()
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
vo = voyageai.Client() vo = voyageai.Client()
chunks = chunk_text(text, MAX_TOKENS) return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
return [vo.embed(chunk, model=model) for chunk in chunks]
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]: def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
@ -182,4 +176,4 @@ def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n
if isinstance(content, bytes): if isinstance(content, bytes):
content = content.decode("utf-8") content = content.decode("utf-8")
return get_type(mime_type), embed_text(content, model, n_dimensions) return get_modality(mime_type), embed_text(content, model, n_dimensions)

View File

@ -168,9 +168,18 @@ def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAtt
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}") logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
return None return None
source_item = SourceItem(
modality=embedding.get_modality(attachment["content_type"]),
sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(),
tags=message.source.tags,
byte_length=attachment["size"],
mime_type=attachment["content_type"],
)
return EmailAttachment( return EmailAttachment(
mail_message_id=message.id, source=source_item,
filename=attachment["filename"], filename=attachment["filename"],
mail_message=message,
content_type=attachment.get("content_type"), content_type=attachment.get("content_type"),
size=attachment.get("size"), size=attachment.get("size"),
content=content, content=content,
@ -242,7 +251,9 @@ def create_source_item(
db_session: Session, db_session: Session,
message_hash: bytes, message_hash: bytes,
account_tags: list[str], account_tags: list[str],
raw_email_size: int, raw_size: int,
modality: str = "mail",
mime_type: str = "message/rfc822",
) -> SourceItem: ) -> SourceItem:
""" """
Create a new source item record. Create a new source item record.
@ -251,17 +262,17 @@ def create_source_item(
db_session: Database session db_session: Database session
message_hash: SHA-256 hash of message message_hash: SHA-256 hash of message
account_tags: Tags from the email account account_tags: Tags from the email account
raw_email_size: Size of raw email in bytes raw_size: Size of raw email in bytes
Returns: Returns:
Newly created SourceItem Newly created SourceItem
""" """
source_item = SourceItem( source_item = SourceItem(
modality="mail", modality=modality,
sha256=message_hash, sha256=message_hash,
tags=account_tags, tags=account_tags,
byte_length=raw_email_size, byte_length=raw_size,
mime_type="message/rfc822", mime_type=mime_type,
embed_status="RAW" embed_status="RAW"
) )
db_session.add(source_item) db_session.add(source_item)
@ -271,7 +282,7 @@ def create_source_item(
def create_mail_message( def create_mail_message(
db_session: Session, db_session: Session,
source_id: int, source_item: SourceItem,
parsed_email: EmailMessage, parsed_email: EmailMessage,
folder: str, folder: str,
) -> MailMessage: ) -> MailMessage:
@ -288,7 +299,7 @@ def create_mail_message(
Newly created MailMessage Newly created MailMessage
""" """
mail_message = MailMessage( mail_message = MailMessage(
source_id=source_id, source=source_item,
message_id=parsed_email["message_id"], message_id=parsed_email["message_id"],
subject=parsed_email["subject"], subject=parsed_email["subject"],
sender=parsed_email["sender"], sender=parsed_email["sender"],
@ -462,15 +473,17 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
def vectorize_email(email: MailMessage) -> list[float]: def vectorize_email(email: MailMessage) -> list[float]:
qdrant_client = get_qdrant_client() qdrant_client = get_qdrant_client()
vector_id = uuid.uuid4() chunks = embedding.embed_text(email.body_raw)
payloads = [email.as_payload()] * len(chunks)
vector_ids = [str(uuid.uuid4()) for _ in chunks]
upsert_vectors( upsert_vectors(
client=qdrant_client, client=qdrant_client,
collection_name="mail", collection_name="mail",
ids=[str(vector_id)], ids=vector_ids,
vectors=[embedding.embed_text(email.body_raw)], vectors=chunks,
payloads=[email.as_payload()], payloads=payloads,
) )
vector_ids = [f"mail/{vector_id}"] vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
embeds = defaultdict(list) embeds = defaultdict(list)
for attachment in email.attachments: for attachment in email.attachments:
@ -478,8 +491,11 @@ def vectorize_email(email: MailMessage) -> list[float]:
content = pathlib.Path(attachment.file_path).read_bytes() content = pathlib.Path(attachment.file_path).read_bytes()
else: else:
content = attachment.content content = attachment.content
collection, vector = embedding.embed(attachment.content_type, content) collection, vectors = embedding.embed(attachment.content_type, content)
embeds[collection].append((str(uuid.uuid4()), vector, attachment.as_payload())) attachment.source.vector_ids = vector_ids
embeds[collection].extend(
(str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
)
for collection, embeds in embeds.items(): for collection, embeds in embeds.items():
ids, vectors, payloads = zip(*embeds) ids, vectors, payloads = zip(*embeds)

View File

@ -63,8 +63,7 @@ def process_message(
return None return None
source_item = create_source_item(db, message_hash, account.tags, len(raw_email)) source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
mail_message = create_mail_message(db, source_item, parsed_email, folder)
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
source_item.vector_ids = vectorize_email(mail_message) source_item.vector_ids = vectorize_email(mail_message)
db.commit() db.commit()

View File

@ -1,5 +1,5 @@
import pytest import pytest
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,7 +1,9 @@
from unittest import mock
import pytest import pytest
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import patch
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import embedding
from memory.workers.tasks.email import process_message from memory.workers.tasks.email import process_message
@ -36,6 +38,12 @@ VGhpcyBpcyBhIHRlc3QgYXR0YWNobWVudA==
--boundary123--""" --boundary123--"""
@pytest.fixture(autouse=True)
def mock_voyage_embed_text():
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
yield
@pytest.fixture @pytest.fixture
def test_email_account(db_session): def test_email_account(db_session):
"""Create a test email account for integration testing.""" """Create a test email account for integration testing."""

View File

@ -11,6 +11,7 @@ from unittest.mock import ANY, MagicMock, patch
import pytest import pytest
from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount
from memory.common import settings from memory.common import settings
from memory.common import embedding
from memory.workers.email import ( from memory.workers.email import (
compute_message_hash, compute_message_hash,
create_source_item, create_source_item,
@ -236,6 +237,7 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
} }
message = MailMessage( message = MailMessage(
id=1, id=1,
source=SourceItem(tags=["test"]),
message_id=message_id, message_id=message_id,
sender="sender@example.com", sender="sender@example.com",
folder="INBOX", folder="INBOX",
@ -271,6 +273,7 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
} }
message = MailMessage( message = MailMessage(
id=1, id=1,
source=SourceItem(tags=["test"]),
message_id=message_id, message_id=message_id,
sender="sender@example.com", sender="sender@example.com",
folder="INBOX", folder="INBOX",
@ -293,6 +296,7 @@ def test_process_attachment_write_error():
} }
message = MailMessage( message = MailMessage(
id=1, id=1,
source=SourceItem(tags=["test"]),
message_id="<test@example.com>", message_id="<test@example.com>",
sender="sender@example.com", sender="sender@example.com",
folder="INBOX", folder="INBOX",
@ -340,6 +344,7 @@ def test_process_attachments_mixed():
] ]
message = MailMessage( message = MailMessage(
id=1, id=1,
source=SourceItem(tags=["test"]),
message_id="<test@example.com>", message_id="<test@example.com>",
sender="sender@example.com", sender="sender@example.com",
folder="INBOX", folder="INBOX",
@ -458,7 +463,7 @@ def test_create_source_item(db_session):
db_session=db_session, db_session=db_session,
message_hash=message_hash, message_hash=message_hash,
account_tags=account_tags, account_tags=account_tags,
raw_email_size=raw_email_size, raw_size=raw_email_size,
) )
# Verify the source item was created correctly # Verify the source item was created correctly
@ -549,7 +554,6 @@ def test_check_message_exists(
def test_create_mail_message(db_session): def test_create_mail_message(db_session):
source_item = SourceItem( source_item = SourceItem(
id=1,
modality="mail", modality="mail",
sha256=b"test_hash_bytes" + bytes(28), sha256=b"test_hash_bytes" + bytes(28),
tags=["test"], tags=["test"],
@ -557,7 +561,6 @@ def test_create_mail_message(db_session):
) )
db_session.add(source_item) db_session.add(source_item)
db_session.flush() db_session.flush()
source_id = source_item.id
parsed_email = { parsed_email = {
"message_id": "<test@example.com>", "message_id": "<test@example.com>",
"subject": "Test Subject", "subject": "Test Subject",
@ -574,21 +577,22 @@ def test_create_mail_message(db_session):
# Call function # Call function
mail_message = create_mail_message( mail_message = create_mail_message(
db_session=db_session, db_session=db_session,
source_id=source_id, source_item=source_item,
parsed_email=parsed_email, parsed_email=parsed_email,
folder=folder, folder=folder,
) )
db_session.commit()
attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all() attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all()
# Verify the mail message was created correctly # Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage) assert isinstance(mail_message, MailMessage)
assert mail_message.source_id == source_id assert mail_message.source_id == source_item.id
assert mail_message.message_id == parsed_email["message_id"] assert mail_message.message_id == parsed_email["message_id"]
assert mail_message.subject == parsed_email["subject"] assert mail_message.subject == parsed_email["subject"]
assert mail_message.sender == parsed_email["sender"] assert mail_message.sender == parsed_email["sender"]
assert mail_message.recipients == parsed_email["recipients"] assert mail_message.recipients == parsed_email["recipients"]
assert mail_message.sent_at == parsed_email["sent_at"] assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat()
assert mail_message.body_raw == parsed_email["body"] assert mail_message.body_raw == parsed_email["body"]
assert mail_message.attachments == attachments assert mail_message.attachments == attachments
@ -670,22 +674,15 @@ def test_process_folder_error(email_provider):
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4): def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
source_item = SourceItem( mail_message = MailMessage(
id=1, source=SourceItem(
modality="mail", modality="mail",
sha256=b"test_hash" + bytes(24), sha256=b"test_hash" + bytes(24),
tags=["test"], tags=["test"],
byte_length=100, byte_length=100,
mime_type="message/rfc822", mime_type="message/rfc822",
embed_status="RAW", embed_status="RAW",
) ),
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=1,
source_id=1,
message_id="<test-vector@example.com>", message_id="<test-vector@example.com>",
subject="Test Vectorization", subject="Test Vectorization",
sender="sender@example.com", sender="sender@example.com",
@ -696,7 +693,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
db_session.add(mail_message) db_session.add(mail_message)
db_session.flush() db_session.flush()
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536): with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
vector_ids = vectorize_email(mail_message) vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1 assert len(vector_ids) == 1
@ -704,22 +701,15 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
source_item = SourceItem( mail_message = MailMessage(
id=2, source=SourceItem(
modality="mail", modality="mail",
sha256=b"test_hash2" + bytes(24), sha256=b"test_hash" + bytes(24),
tags=["test"], tags=["test"],
byte_length=200, byte_length=100,
mime_type="message/rfc822", mime_type="message/rfc822",
embed_status="RAW", embed_status="RAW",
) ),
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=2,
source_id=2,
message_id="<test-vector-attach@example.com>", message_id="<test-vector-attach@example.com>",
subject="Test Vectorization with Attachments", subject="Test Vectorization with Attachments",
sender="sender@example.com", sender="sender@example.com",
@ -732,33 +722,47 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
# Add two attachments - one with content and one with file_path # Add two attachments - one with content and one with file_path
attachment1 = EmailAttachment( attachment1 = EmailAttachment(
id=1,
mail_message_id=mail_message.id, mail_message_id=mail_message.id,
filename="inline.txt", filename="inline.txt",
content_type="text/plain", content_type="text/plain",
size=100, size=100,
content=base64.b64encode(b"This is inline content"), content=base64.b64encode(b"This is inline content"),
file_path=None, file_path=None,
source=SourceItem(
modality="doc",
sha256=b"test_hash1" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
) )
file_path = mail_message.attachments_path / "stored.txt" file_path = mail_message.attachments_path / "stored.txt"
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(b"This is stored content") file_path.write_bytes(b"This is stored content")
attachment2 = EmailAttachment( attachment2 = EmailAttachment(
id=2,
mail_message_id=mail_message.id, mail_message_id=mail_message.id,
filename="stored.txt", filename="stored.txt",
content_type="text/plain", content_type="text/plain",
size=200, size=200,
content=None, content=None,
file_path=str(file_path), file_path=str(file_path),
source=SourceItem(
modality="doc",
sha256=b"test_hash2" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
) )
db_session.add_all([attachment1, attachment2]) db_session.add_all([attachment1, attachment2])
db_session.flush() db_session.flush()
# Mock embedding functions but use real qdrant # Mock embedding functions but use real qdrant
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536): with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
# Call the function # Call the function
vector_ids = vectorize_email(mail_message) vector_ids = vectorize_email(mail_message)