mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
embeding tests
This commit is contained in:
parent
2d2f37536a
commit
869e5ac6b4
@ -172,6 +172,7 @@ def upgrade() -> None:
|
||||
op.create_table(
|
||||
"email_attachment",
|
||||
sa.Column("id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("source_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("mail_message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("filename", sa.Text(), nullable=False),
|
||||
sa.Column("content_type", sa.Text(), nullable=True),
|
||||
@ -187,6 +188,9 @@ def upgrade() -> None:
|
||||
sa.ForeignKeyConstraint(
|
||||
["mail_message_id"], ["mail_message.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_id"], ["source_item.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
|
@ -2,4 +2,6 @@ sqlalchemy==2.0.30
|
||||
psycopg2-binary==2.9.9
|
||||
pydantic==2.7.1
|
||||
alembic==1.13.1
|
||||
dotenv==1.1.0
|
||||
dotenv==1.1.0
|
||||
voyageai==0.3.2
|
||||
qdrant-client==1.9.0
|
@ -1,4 +1,3 @@
|
||||
celery==5.3.6
|
||||
openai==1.25.0
|
||||
pillow==10.3.0
|
||||
qdrant-client==1.9.0
|
@ -30,7 +30,7 @@ class SourceItem(Base):
|
||||
mime_type = Column(Text)
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
|
||||
|
||||
attachments = relationship("EmailAttachment", back_populates="source", cascade="all, delete-orphan", uselist=False)
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
@ -85,6 +85,7 @@ class EmailAttachment(Base):
|
||||
__tablename__ = 'email_attachment'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
source_id = Column(BigInteger, ForeignKey('source_item.id', ondelete='CASCADE'), nullable=False)
|
||||
mail_message_id = Column(BigInteger, ForeignKey('mail_message.id', ondelete='CASCADE'), nullable=False)
|
||||
filename = Column(Text, nullable=False)
|
||||
content_type = Column(Text)
|
||||
@ -94,6 +95,7 @@ class EmailAttachment(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="attachments")
|
||||
source = relationship("SourceItem", back_populates="attachments")
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
|
@ -36,7 +36,7 @@ TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def get_type(mime_type: str) -> str:
|
||||
def get_modality(mime_type: str) -> str:
|
||||
for type, mime_types in TYPES.items():
|
||||
if mime_type in mime_types:
|
||||
return type
|
||||
@ -109,7 +109,7 @@ def yield_spans(text: str, max_tokens: int = MAX_TOKENS) -> Iterable[str]:
|
||||
yield chunk
|
||||
|
||||
|
||||
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> list[str]:
|
||||
def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_TOKENS) -> Iterable[str]:
|
||||
"""
|
||||
Split text into chunks respecting semantic boundaries while staying within token limits.
|
||||
|
||||
@ -149,10 +149,6 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
|
||||
overlap_text.rfind("? ")
|
||||
)
|
||||
|
||||
print(f"clean_break: {clean_break}")
|
||||
print(f"overlap_text: {overlap_text}")
|
||||
print(f"current: {current}")
|
||||
|
||||
if clean_break < 0:
|
||||
yield current
|
||||
current = ""
|
||||
@ -160,18 +156,16 @@ def chunk_text(text: str, max_tokens: int = MAX_TOKENS, overlap: int = OVERLAP_T
|
||||
|
||||
break_offset = -overlap_chars + clean_break + 1
|
||||
chunk = current[break_offset:].strip()
|
||||
print(f"chunk: {chunk}")
|
||||
print(f"current: {current}")
|
||||
yield current
|
||||
current = chunk
|
||||
|
||||
if current:
|
||||
yield current.strip()
|
||||
|
||||
|
||||
def embed_text(text: str, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
vo = voyageai.Client()
|
||||
chunks = chunk_text(text, MAX_TOKENS)
|
||||
return [vo.embed(chunk, model=model) for chunk in chunks]
|
||||
return vo.embed(chunk_text(text, MAX_TOKENS, OVERLAP_TOKENS), model=model)
|
||||
|
||||
|
||||
def embed_file(file_path: pathlib.Path, model: str = "voyage-3-large", n_dimensions: int = 1536) -> list[Vector]:
|
||||
@ -182,4 +176,4 @@ def embed(mime_type: str, content: bytes | str, model: str = "voyage-3-large", n
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return get_type(mime_type), embed_text(content, model, n_dimensions)
|
||||
return get_modality(mime_type), embed_text(content, model, n_dimensions)
|
||||
|
@ -168,9 +168,18 @@ def process_attachment(attachment: Attachment, message: MailMessage) -> EmailAtt
|
||||
logger.error(f"Failed to save attachment {safe_filename} to disk: {str(e)}")
|
||||
return None
|
||||
|
||||
source_item = SourceItem(
|
||||
modality=embedding.get_modality(attachment["content_type"]),
|
||||
sha256=hashlib.sha256(real_content if real_content else str(attachment).encode()).digest(),
|
||||
tags=message.source.tags,
|
||||
byte_length=attachment["size"],
|
||||
mime_type=attachment["content_type"],
|
||||
)
|
||||
|
||||
return EmailAttachment(
|
||||
mail_message_id=message.id,
|
||||
source=source_item,
|
||||
filename=attachment["filename"],
|
||||
mail_message=message,
|
||||
content_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
content=content,
|
||||
@ -242,7 +251,9 @@ def create_source_item(
|
||||
db_session: Session,
|
||||
message_hash: bytes,
|
||||
account_tags: list[str],
|
||||
raw_email_size: int,
|
||||
raw_size: int,
|
||||
modality: str = "mail",
|
||||
mime_type: str = "message/rfc822",
|
||||
) -> SourceItem:
|
||||
"""
|
||||
Create a new source item record.
|
||||
@ -251,17 +262,17 @@ def create_source_item(
|
||||
db_session: Database session
|
||||
message_hash: SHA-256 hash of message
|
||||
account_tags: Tags from the email account
|
||||
raw_email_size: Size of raw email in bytes
|
||||
raw_size: Size of raw email in bytes
|
||||
|
||||
Returns:
|
||||
Newly created SourceItem
|
||||
"""
|
||||
source_item = SourceItem(
|
||||
modality="mail",
|
||||
modality=modality,
|
||||
sha256=message_hash,
|
||||
tags=account_tags,
|
||||
byte_length=raw_email_size,
|
||||
mime_type="message/rfc822",
|
||||
byte_length=raw_size,
|
||||
mime_type=mime_type,
|
||||
embed_status="RAW"
|
||||
)
|
||||
db_session.add(source_item)
|
||||
@ -271,7 +282,7 @@ def create_source_item(
|
||||
|
||||
def create_mail_message(
|
||||
db_session: Session,
|
||||
source_id: int,
|
||||
source_item: SourceItem,
|
||||
parsed_email: EmailMessage,
|
||||
folder: str,
|
||||
) -> MailMessage:
|
||||
@ -288,7 +299,7 @@ def create_mail_message(
|
||||
Newly created MailMessage
|
||||
"""
|
||||
mail_message = MailMessage(
|
||||
source_id=source_id,
|
||||
source=source_item,
|
||||
message_id=parsed_email["message_id"],
|
||||
subject=parsed_email["subject"],
|
||||
sender=parsed_email["sender"],
|
||||
@ -462,15 +473,17 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
||||
def vectorize_email(email: MailMessage) -> list[float]:
|
||||
qdrant_client = get_qdrant_client()
|
||||
|
||||
vector_id = uuid.uuid4()
|
||||
chunks = embedding.embed_text(email.body_raw)
|
||||
payloads = [email.as_payload()] * len(chunks)
|
||||
vector_ids = [str(uuid.uuid4()) for _ in chunks]
|
||||
upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
ids=[str(vector_id)],
|
||||
vectors=[embedding.embed_text(email.body_raw)],
|
||||
payloads=[email.as_payload()],
|
||||
ids=vector_ids,
|
||||
vectors=chunks,
|
||||
payloads=payloads,
|
||||
)
|
||||
vector_ids = [f"mail/{vector_id}"]
|
||||
vector_ids = [f"mail/{vector_id}" for vector_id in vector_ids]
|
||||
|
||||
embeds = defaultdict(list)
|
||||
for attachment in email.attachments:
|
||||
@ -478,8 +491,11 @@ def vectorize_email(email: MailMessage) -> list[float]:
|
||||
content = pathlib.Path(attachment.file_path).read_bytes()
|
||||
else:
|
||||
content = attachment.content
|
||||
collection, vector = embedding.embed(attachment.content_type, content)
|
||||
embeds[collection].append((str(uuid.uuid4()), vector, attachment.as_payload()))
|
||||
collection, vectors = embedding.embed(attachment.content_type, content)
|
||||
attachment.source.vector_ids = vector_ids
|
||||
embeds[collection].extend(
|
||||
(str(uuid.uuid4()), vector, attachment.as_payload()) for vector in vectors
|
||||
)
|
||||
|
||||
for collection, embeds in embeds.items():
|
||||
ids, vectors, payloads = zip(*embeds)
|
||||
|
@ -63,8 +63,7 @@ def process_message(
|
||||
return None
|
||||
|
||||
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
|
||||
|
||||
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
|
||||
mail_message = create_mail_message(db, source_item, parsed_email, folder)
|
||||
|
||||
source_item.vector_ids = vectorize_email(mail_message)
|
||||
db.commit()
|
||||
|
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN, MAX_TOKENS, approx_token_count
|
||||
from memory.common.embedding import yield_word_chunks, yield_spans, chunk_text, CHARS_PER_TOKEN
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -1,7 +1,9 @@
|
||||
from unittest import mock
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from unittest.mock import patch
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common import embedding
|
||||
from memory.workers.tasks.email import process_message
|
||||
|
||||
|
||||
@ -36,6 +38,12 @@ VGhpcyBpcyBhIHRlc3QgYXR0YWNobWVudA==
|
||||
--boundary123--"""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_voyage_embed_text():
|
||||
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_email_account(db_session):
|
||||
"""Create a test email account for integration testing."""
|
||||
|
@ -11,6 +11,7 @@ from unittest.mock import ANY, MagicMock, patch
|
||||
import pytest
|
||||
from memory.common.db.models import SourceItem, MailMessage, EmailAttachment, EmailAccount
|
||||
from memory.common import settings
|
||||
from memory.common import embedding
|
||||
from memory.workers.email import (
|
||||
compute_message_hash,
|
||||
create_source_item,
|
||||
@ -236,6 +237,7 @@ def test_process_attachment_inline(attachment_size, max_inline_size, message_id)
|
||||
}
|
||||
message = MailMessage(
|
||||
id=1,
|
||||
source=SourceItem(tags=["test"]),
|
||||
message_id=message_id,
|
||||
sender="sender@example.com",
|
||||
folder="INBOX",
|
||||
@ -271,6 +273,7 @@ def test_process_attachment_disk(attachment_size, max_inline_size, message_id):
|
||||
}
|
||||
message = MailMessage(
|
||||
id=1,
|
||||
source=SourceItem(tags=["test"]),
|
||||
message_id=message_id,
|
||||
sender="sender@example.com",
|
||||
folder="INBOX",
|
||||
@ -293,6 +296,7 @@ def test_process_attachment_write_error():
|
||||
}
|
||||
message = MailMessage(
|
||||
id=1,
|
||||
source=SourceItem(tags=["test"]),
|
||||
message_id="<test@example.com>",
|
||||
sender="sender@example.com",
|
||||
folder="INBOX",
|
||||
@ -340,6 +344,7 @@ def test_process_attachments_mixed():
|
||||
]
|
||||
message = MailMessage(
|
||||
id=1,
|
||||
source=SourceItem(tags=["test"]),
|
||||
message_id="<test@example.com>",
|
||||
sender="sender@example.com",
|
||||
folder="INBOX",
|
||||
@ -458,7 +463,7 @@ def test_create_source_item(db_session):
|
||||
db_session=db_session,
|
||||
message_hash=message_hash,
|
||||
account_tags=account_tags,
|
||||
raw_email_size=raw_email_size,
|
||||
raw_size=raw_email_size,
|
||||
)
|
||||
|
||||
# Verify the source item was created correctly
|
||||
@ -549,7 +554,6 @@ def test_check_message_exists(
|
||||
|
||||
def test_create_mail_message(db_session):
|
||||
source_item = SourceItem(
|
||||
id=1,
|
||||
modality="mail",
|
||||
sha256=b"test_hash_bytes" + bytes(28),
|
||||
tags=["test"],
|
||||
@ -557,7 +561,6 @@ def test_create_mail_message(db_session):
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
source_id = source_item.id
|
||||
parsed_email = {
|
||||
"message_id": "<test@example.com>",
|
||||
"subject": "Test Subject",
|
||||
@ -574,21 +577,22 @@ def test_create_mail_message(db_session):
|
||||
# Call function
|
||||
mail_message = create_mail_message(
|
||||
db_session=db_session,
|
||||
source_id=source_id,
|
||||
source_item=source_item,
|
||||
parsed_email=parsed_email,
|
||||
folder=folder,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
attachments = db_session.query(EmailAttachment).filter(EmailAttachment.mail_message_id == mail_message.id).all()
|
||||
|
||||
# Verify the mail message was created correctly
|
||||
assert isinstance(mail_message, MailMessage)
|
||||
assert mail_message.source_id == source_id
|
||||
assert mail_message.source_id == source_item.id
|
||||
assert mail_message.message_id == parsed_email["message_id"]
|
||||
assert mail_message.subject == parsed_email["subject"]
|
||||
assert mail_message.sender == parsed_email["sender"]
|
||||
assert mail_message.recipients == parsed_email["recipients"]
|
||||
assert mail_message.sent_at == parsed_email["sent_at"]
|
||||
assert mail_message.sent_at.isoformat()[:-6] == parsed_email["sent_at"].isoformat()
|
||||
assert mail_message.body_raw == parsed_email["body"]
|
||||
assert mail_message.attachments == attachments
|
||||
|
||||
@ -670,22 +674,15 @@ def test_process_folder_error(email_provider):
|
||||
|
||||
|
||||
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
||||
source_item = SourceItem(
|
||||
id=1,
|
||||
modality="mail",
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
|
||||
# Create mail message
|
||||
mail_message = MailMessage(
|
||||
id=1,
|
||||
source_id=1,
|
||||
source=SourceItem(
|
||||
modality="mail",
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
),
|
||||
message_id="<test-vector@example.com>",
|
||||
subject="Test Vectorization",
|
||||
sender="sender@example.com",
|
||||
@ -696,7 +693,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
|
||||
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
|
||||
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
|
||||
vector_ids = vectorize_email(mail_message)
|
||||
|
||||
assert len(vector_ids) == 1
|
||||
@ -704,22 +701,15 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
||||
|
||||
|
||||
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
||||
source_item = SourceItem(
|
||||
id=2,
|
||||
modality="mail",
|
||||
sha256=b"test_hash2" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=200,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
|
||||
# Create mail message
|
||||
mail_message = MailMessage(
|
||||
id=2,
|
||||
source_id=2,
|
||||
source=SourceItem(
|
||||
modality="mail",
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
),
|
||||
message_id="<test-vector-attach@example.com>",
|
||||
subject="Test Vectorization with Attachments",
|
||||
sender="sender@example.com",
|
||||
@ -732,33 +722,47 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
||||
|
||||
# Add two attachments - one with content and one with file_path
|
||||
attachment1 = EmailAttachment(
|
||||
id=1,
|
||||
mail_message_id=mail_message.id,
|
||||
filename="inline.txt",
|
||||
content_type="text/plain",
|
||||
size=100,
|
||||
content=base64.b64encode(b"This is inline content"),
|
||||
file_path=None,
|
||||
source=SourceItem(
|
||||
modality="doc",
|
||||
sha256=b"test_hash1" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="text/plain",
|
||||
embed_status="RAW",
|
||||
),
|
||||
)
|
||||
|
||||
file_path = mail_message.attachments_path / "stored.txt"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_bytes(b"This is stored content")
|
||||
attachment2 = EmailAttachment(
|
||||
id=2,
|
||||
mail_message_id=mail_message.id,
|
||||
filename="stored.txt",
|
||||
content_type="text/plain",
|
||||
size=200,
|
||||
content=None,
|
||||
file_path=str(file_path),
|
||||
source=SourceItem(
|
||||
modality="doc",
|
||||
sha256=b"test_hash2" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="text/plain",
|
||||
embed_status="RAW",
|
||||
),
|
||||
)
|
||||
|
||||
db_session.add_all([attachment1, attachment2])
|
||||
db_session.flush()
|
||||
|
||||
# Mock embedding functions but use real qdrant
|
||||
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
|
||||
with patch.object(embedding, "embed_text", return_value=[[0.1] * 1024]):
|
||||
# Call the function
|
||||
vector_ids = vectorize_email(mail_message)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user