embedding tests

This commit is contained in:
Daniel O'Connell 2025-04-28 15:47:26 +02:00
parent 2d2f37536a
commit 869e5ac6b4
10 changed files with 101 additions and 73 deletions

View File

@ -172,6 +172,7 @@ def upgrade() -> None:
op.create_table(
"email_attachment",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
sa.Column("filename", sa.Text(), nullable=False),
sa.Column("content_type", sa.Text(), nullable=True),
@ -187,6 +188,9 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["source_id"], ["source_item.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(

View File

@ -3,3 +3,5 @@ psycopg2-binary==2.9.9
pydantic==2.7.1
alembic==1.13.1
dotenv==1.1.0
voyageai==0.3.2
qdrant-client==1.9.0

View File

@ -1,4 +1,3 @@
celery==5.3.6
openai==1.25.0
pillow==10.3.0
qdrant-client==1.9.0

View File

@ -30,7 +30,7 @@ class SourceItem(Base):
mime_type = Column(Text)
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False)
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
@ -85,6 +85,7 @@ class EmailAttachment(Base):
__tablename__ = 'email_attachment'
id = Column(BigInteger, primary_key=True)
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False)
filename = Column(Text, nullable=False)
content_type = Column(Text)
@ -94,6 +95,7 @@ class EmailAttachment(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
mail_message = relationship("MailMessage", back_populates="attachments")
source = relationship("SourceItem", back_populates="attachments")
def as_payload(self) -> dict:
return {

View File

@ -36,7 +36,7 @@ TYPES = {
}
def get_type(mime_type: str) -> str:
def get_modality(mime_type: str) -> str:
for type, mime_types in TYPES.items():
if mime_type in mime_types:
return type
@ -109,7 +109,7 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
yield chunk
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> list[str]:
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
"""
Split text into chunks respecting semantic boundaries while staying within token limits.
@ -149,10 +149,6 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
overlap_text.rfind("? ")
)
print(f"clean_break: {clean_break}")
print(f"overlap_text: {overlap_text}")
print(f"current: {current}")
if clean_break < 0:
yield current
current = ""
@ -160,18 +156,16 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
break_offset = -overlap_chars + clean_break + 1
chunk = current[break_offset:].strip()
print(f"chunk: {chunk}")
print(f"current: {current}")
yield current
current = chunk
if current:
yield current.strip()
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
vo = voyageai.Client()
chunks = chunk_text(text, MAX_TOKENS)
return [vo.embed(chunk, model=model) for chunk in chunks]
return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
@ -182,4 +176,4 @@ def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n
if isinstance(content, bytes):
content = content.decode("utf-8")
return get_type(mime_type), embed_text(content, model, n_dimensions)
return get_modality(mime_type), embed_text(content, model, n_dimensions)

View File

@ -168,9 +168,18 @@ def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAtt
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
return None
source_item = SourceItem(
modality=embedding.get_modality(attachment["content_type"]),
sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(),
tags=message.source.tags,
byte_length=attachment["size"],
mime_type=attachment["content_type"],
)
return EmailAttachment(
mail_message_id=message.id,
source=source_item,
filename=attachment["filename"],
mail_message=message,
content_type=attachment.get("content_type"),
size=attachment.get("size"),
content=content,
@ -242,7 +251,9 @@ def create_source_item(
db_session: Session,
message_hash: bytes,
account_tags: list[str],
raw_email_size: int,
raw_size: int,
modality: str = "mail",
mime_type: str = "message/rfc822",
) -> SourceItem:
"""
Create a new source item record.
@ -251,17 +262,17 @@ def create_source_item(
db_session: Database session
message_hash: SHA-256 hash of message
account_tags: Tags from the email account
raw_email_size: Size of raw email in bytes
raw_size: Size of raw email in bytes
Returns:
Newly created SourceItem
"""
source_item = SourceItem(
modality="mail",
modality=modality,
sha256=message_hash,
tags=account_tags,
byte_length=raw_email_size,
mime_type="message/rfc822",
byte_length=raw_size,
mime_type=mime_type,
embed_status="RAW"
)
db_session.add(source_item)
@ -271,7 +282,7 @@ def create_source_item(
def create_mail_message(
db_session: Session,
source_id: int,
source_item: SourceItem,
parsed_email: EmailMessage,
folder: str,
) -> MailMessage:
@ -288,7 +299,7 @@ def create_mail_message(
Newly created MailMessage
"""
mail_message = MailMessage(
source_id=source_id,
source=source_item,
message_id=parsed_email["message_id"],
subject=parsed_email["subject"],
sender=parsed_email["sender"],
@ -462,15 +473,17 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
def vectorize_email(email: MailMessage) -> list[float]:
qdrant_client = get_qdrant_client()
vector_id = uuid.uuid4()
chunks = embedding.embed_text(email.body_raw)
payloads = [email.as_payload()] * len(chunks)
vector_ids = [str(uuid.uuid4()) for _ in chunks]
upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=[str(vector_id)],
vectors=[embedding.embed_text(email.body_raw)],
payloads=[email.as_payload()],
ids=vector_ids,
vectors=chunks,
payloads=payloads,
)
vector_ids = [f"mail/{vector_id}"]
vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
embeds = defaultdict(list)
for attachment in email.attachments:
@ -478,8 +491,11 @@ def vectorize_email(email: MailMessage) -> list[float]:
content = pathlib.Path(attachment.file_path).read_bytes()
else:
content = attachment.content
collection, vector = embedding.embed(attachment.content_type, content)
embeds[collection].append((str(uuid.uuid4()), vector, attachment.as_payload()))
collection, vectors = embedding.embed(attachment.content_type, content)
attachment.source.vector_ids = vector_ids
embeds[collection].extend(
(str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
)
for collection, embeds in embeds.items():
ids, vectors, payloads = zip(*embeds)

View File

@ -63,8 +63,7 @@ def process_message(
return None
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
mail_message = create_mail_message(db, source_item, parsed_email, folder)
source_item.vector_ids = vectorize_email(mail_message)
db.commit()

View File

@ -1,5 +1,5 @@
import pytest
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
@pytest.mark.parametrize(

View File

@ -1,7 +1,9 @@
from unittest import mock
import pytest
from datetime import datetime, timedelta
from unittest.mock import patch
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import embedding
from memory.workers.tasks.email import process_message
@ -36,6 +38,12 @@ VGhpcyBpcyBhIHRlc3QgYXR0YWNobWVudA==
--boundary123--"""
# Autouse fixture: for the duration of each test that picks it up, replaces
# `embedding.embed_text` with a stub that returns a list holding a single
# 1024-element vector filled with 0.1. The patch is undone when the `with`
# block exits after the test finishes.
# NOTE(review): presumably this keeps tests from calling the real Voyage
# embedding service — confirm. Callers that bound `embed_text` by name before
# the patch is applied would not see the stub.
@pytest.fixture(autouse=True)
def mock_voyage_embed_text():
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
yield
@pytest.fixture
def test_email_account(db_session):
"""Create a test email account for integration testing."""

View File

@ -11,6 +11,7 @@ from unittest.mock import ANY, MagicMock, patch
import pytest
from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount
from memory.common import settings
from memory.common import embedding
from memory.workers.email import (
compute_message_hash,
create_source_item,
@ -236,6 +237,7 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id=message_id,
sender="sender@example.com",
folder="INBOX",
@ -271,6 +273,7 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id=message_id,
sender="sender@example.com",
folder="INBOX",
@ -293,6 +296,7 @@ def test_process_attachment_write_error():
}
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id="<test@example.com>",
sender="sender@example.com",
folder="INBOX",
@ -340,6 +344,7 @@ def test_process_attachments_mixed():
]
message = MailMessage(
id=1,
source=SourceItem(tags=["test"]),
message_id="<test@example.com>",
sender="sender@example.com",
folder="INBOX",
@ -458,7 +463,7 @@ def test_create_source_item(db_session):
db_session=db_session,
message_hash=message_hash,
account_tags=account_tags,
raw_email_size=raw_email_size,
raw_size=raw_email_size,
)
# Verify the source item was created correctly
@ -549,7 +554,6 @@ def test_check_message_exists(
def test_create_mail_message(db_session):
source_item = SourceItem(
id=1,
modality="mail",
sha256=b"test_hash_bytes" + bytes(28),
tags=["test"],
@ -557,7 +561,6 @@ def test_create_mail_message(db_session):
)
db_session.add(source_item)
db_session.flush()
source_id = source_item.id
parsed_email = {
"message_id": "<test@example.com>",
"subject": "Test Subject",
@ -574,21 +577,22 @@ def test_create_mail_message(db_session):
# Call function
mail_message = create_mail_message(
db_session=db_session,
source_id=source_id,
source_item=source_item,
parsed_email=parsed_email,
folder=folder,
)
db_session.commit()
attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all()
# Verify the mail message was created correctly
assert isinstance(mail_message, MailMessage)
assert mail_message.source_id == source_id
assert mail_message.source_id == source_item.id
assert mail_message.message_id == parsed_email["message_id"]
assert mail_message.subject == parsed_email["subject"]
assert mail_message.sender == parsed_email["sender"]
assert mail_message.recipients == parsed_email["recipients"]
assert mail_message.sent_at == parsed_email["sent_at"]
assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat()
assert mail_message.body_raw == parsed_email["body"]
assert mail_message.attachments == attachments
@ -670,22 +674,15 @@ def test_process_folder_error(email_provider):
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
source_item = SourceItem(
id=1,
mail_message = MailMessage(
source=SourceItem(
modality="mail",
sha256=b"test_hash" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=1,
source_id=1,
),
message_id="<test-vector@example.com>",
subject="Test Vectorization",
sender="sender@example.com",
@ -696,7 +693,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
db_session.add(mail_message)
db_session.flush()
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1
@ -704,22 +701,15 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
source_item = SourceItem(
id=2,
mail_message = MailMessage(
source=SourceItem(
modality="mail",
sha256=b"test_hash2" + bytes(24),
sha256=b"test_hash" + bytes(24),
tags=["test"],
byte_length=200,
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=2,
source_id=2,
),
message_id="<test-vector-attach@example.com>",
subject="Test Vectorization with Attachments",
sender="sender@example.com",
@ -732,33 +722,47 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
# Add two attachments - one with content and one with file_path
attachment1 = EmailAttachment(
id=1,
mail_message_id=mail_message.id,
filename="inline.txt",
content_type="text/plain",
size=100,
content=base64.b64encode(b"This is inline content"),
file_path=None,
source=SourceItem(
modality="doc",
sha256=b"test_hash1" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
)
file_path = mail_message.attachments_path / "stored.txt"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(b"This is stored content")
attachment2 = EmailAttachment(
id=2,
mail_message_id=mail_message.id,
filename="stored.txt",
content_type="text/plain",
size=200,
content=None,
file_path=str(file_path),
source=SourceItem(
modality="doc",
sha256=b"test_hash2" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="text/plain",
embed_status="RAW",
),
)
db_session.add_all([attachment1, attachment2])
db_session.flush()
# Mock embedding functions but use real qdrant
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
# Call the function
vector_ids = vectorize_email(mail_message)