From a104a3211be57d1d5cbc5171f198e1bb378d956c Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 27 Apr 2025 22:24:30 +0200 Subject: [PATCH] add qdrant --- requirements-dev.txt | 3 +- src/memory/common/db/models.py | 26 ++ src/memory/common/embedding.py | 12 + src/memory/common/settings.py | 13 +- src/memory/workers/email.py | 32 ++- src/memory/workers/qdrant.py | 231 ++++++++++++++++++ src/memory/workers/tasks/email.py | 11 +- tests/conftest.py | 13 + .../memory/workers/tasks/test_email_tasks.py | 8 +- tests/memory/workers/test_email.py | 111 +++++++++ tests/memory/workers/test_qdrant.py | 82 +++++++ 11 files changed, 532 insertions(+), 10 deletions(-) create mode 100644 src/memory/common/embedding.py create mode 100644 src/memory/workers/qdrant.py create mode 100644 tests/memory/workers/test_qdrant.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 4220417..69a11ff 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,5 @@ pytest==7.4.4 pytest-cov==4.1.0 black==23.12.1 mypy==1.8.0 -isort==5.13.2 \ No newline at end of file +isort==5.13.2 +testcontainers[qdrant]==4.10.0 \ No newline at end of file diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index 57389ee..490c0bf 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -29,6 +29,8 @@ class SourceItem(Base): byte_length = Column(Integer) mime_type = Column(Text) + mail_message = relationship("MailMessage", back_populates="source", uselist=False) + # Add table-level constraint and indexes __table_args__ = ( CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), @@ -53,11 +55,24 @@ class MailMessage(Base): tsv = Column(TSVECTOR) attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan") + source = relationship("SourceItem", back_populates="mail_message") @property def attachments_path(self) -> Path: return Path(settings.FILE_STORAGE_DIR) / self.sender / 
def as_payload(self) -> dict:
    """Build the metadata payload stored alongside this message's vector.

    Returns:
        A dict holding the message's identifying fields, the tags of the
        related source item, and ``date`` as the result of
        ``sent_at.isoformat()`` (``None`` when ``sent_at`` is not set).
    """
    # Conditional expression instead of the fragile ``a and b or None`` idiom.
    sent_date = self.sent_at.isoformat() if self.sent_at else None
    return {
        "source_id": self.source_id,
        "message_id": self.message_id,
        "subject": self.subject,
        "sender": self.sender,
        "recipients": self.recipients,
        "folder": self.folder,
        "tags": self.source.tags,
        "date": sent_date,
    }
def boolean_env(key: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean flag.

    Args:
        key: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``True`` when the variable is set to ``1``, ``true`` or ``yes``
        (case-insensitive), ``False`` for any other set value, and
        ``default`` when the variable is unset.
    """
    raw_value = os.getenv(key)
    if raw_value is None:
        # Previously the fallback was hard-coded to "0", so ``default`` was
        # silently ignored.
        return default
    return raw_value.lower() in ("1", "true", "yes")
def vectorize_email(email: MailMessage) -> list[str]:
    """Embed an email and its attachments and upsert the vectors into Qdrant.

    One vector is produced for the message body and one per attachment.
    Each vector gets a freshly generated UUID4 string as its ID and the
    corresponding ``as_payload()`` dict as its payload. All points are
    written to the ``"mail"`` collection in a single upsert.

    Args:
        email: The mail message to vectorize.

    Returns:
        The generated vector IDs: the body's ID first, followed by one ID
        per attachment in iteration order.
    """
    client = get_qdrant_client()

    # The first entry always belongs to the message body.
    vector_ids = [str(uuid.uuid4())]
    vectors = [embedding.embed_text(email.body_raw)]
    payloads = [email.as_payload()]

    for attachment in email.attachments:
        vector_ids.append(str(uuid.uuid4()))
        payloads.append(attachment.as_payload())

        # Attachments stored on disk are embedded from the file; the others
        # are embedded from their inline content.
        # NOTE(review): embed_text is annotated as taking ``str`` — confirm
        # the type of ``attachment.content`` before a real embedder is wired in.
        if attachment.file_path:
            vectors.append(embedding.embed_file(attachment.file_path))
        else:
            vectors.append(embedding.embed_text(attachment.content))

    upsert_vectors(
        client=client,
        collection_name="mail",
        ids=vector_ids,
        vectors=vectors,
        payloads=payloads,
    )

    logger.info(f"Stored embedding for message {email.message_id}")
    return vector_ids
def ensure_collection_exists(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    dimension: int,
    distance: DistanceType = "Cosine",
    on_disk: bool = True,
    shards: int = 1,
) -> bool:
    """
    Ensure a collection exists with the specified parameters.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        dimension: Vector dimension
        distance: Distance metric (Cosine, Dot, Euclidean)
        on_disk: Passed to Qdrant as ``on_disk_payload``, i.e. whether the
            payload is stored on disk
        shards: Number of shards for the collection

    Returns:
        True if the collection was created, False if it already existed

    Raises:
        UnexpectedResponse: If looking up the collection fails with anything
            other than a 404 (for example an auth or server error).
    """
    try:
        client.get_collection(collection_name)
    except UnexpectedResponse as exc:
        # Only "not found" means the collection has to be created; any other
        # HTTP failure is a real error and must not be masked.
        if exc.status_code != 404:
            raise
    except ValueError:
        # Kept from the original handler, which also treated ValueError as
        # "collection missing".
        pass
    else:
        logger.debug(f"Collection {collection_name} already exists")
        return False

    # Qdrant names the Euclidean metric "Euclid"; translate this module's
    # "Euclidean" literal so it resolves to a valid enum member.
    metric_name = "Euclid" if distance == "Euclidean" else distance

    logger.info(f"Creating collection {collection_name} with dimension {dimension}")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=qdrant_models.VectorParams(
            size=dimension,
            distance=qdrant_models.Distance(metric_name),
        ),
        on_disk_payload=on_disk,
        shard_number=shards,
    )

    # Create common payload indexes
    client.create_payload_index(
        collection_name=collection_name,
        field_name="tags",
        field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
    )

    return True
def upsert_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
    vectors: list[Vector],
    payloads: list[dict[str, Any]] = None,
) -> None:
    """Upsert vectors into a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: List of vector IDs (as strings)
        vectors: List of vectors, one per ID
        payloads: List of payloads, one per vector. When omitted, every
            point gets an empty payload.

    Raises:
        ValueError: If ``ids``, ``vectors`` and ``payloads`` do not all have
            the same length.
    """
    if payloads is None:
        payloads = [{} for _ in ids]

    # zip() would silently drop the surplus entries of the longer lists, so
    # reject mismatched input instead of storing only part of it.
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError(
            "ids, vectors and payloads must have the same length "
            f"(got {len(ids)}, {len(vectors)} and {len(payloads)})"
        )

    points = [
        qdrant_models.PointStruct(
            id=id_str,
            vector=vector,
            payload=payload,
        )
        for id_str, vector, payload in zip(ids, vectors, payloads)
    ]

    client.upsert(
        collection_name=collection_name,
        points=points,
    )

    logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
def delete_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
) -> None:
    """
    Remove the points with the given IDs from a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection to delete from
        ids: IDs of the vectors to remove
    """
    # Select the points to drop explicitly by their IDs.
    selector = qdrant_models.PointIdsList(points=ids)
    client.delete(collection_name=collection_name, points_selector=selector)

    removed_count = len(ids)
    logger.debug(f"Deleted {removed_count} vectors from {collection_name}")
+ + Args: + client: Qdrant client + collection_name: Name of the collection + + Returns: + Dictionary with collection information + """ + info = client.get_collection(collection_name) + return info.model_dump() diff --git a/src/memory/workers/tasks/email.py b/src/memory/workers/tasks/email.py index 654d4a5..a37ad16 100644 --- a/src/memory/workers/tasks/email.py +++ b/src/memory/workers/tasks/email.py @@ -12,6 +12,7 @@ from memory.workers.email import ( imap_connection, parse_email_message, process_folder, + vectorize_email, ) @@ -63,11 +64,15 @@ def process_message( source_item = create_source_item(db, message_hash, account.tags, len(raw_email)) - create_mail_message(db, source_item.id, parsed_email, folder) + mail_message = create_mail_message(db, source_item.id, parsed_email, folder) + source_item.vector_ids = vectorize_email(mail_message) db.commit() - - # TODO: Queue for embedding once that's implemented + + logger.info(f"Stored embedding for message {parsed_email['message_id']}") + logger.info("Vector IDs:") + for vector_id in source_item.vector_ids: + logger.info(f" - {vector_id}") return source_item.id diff --git a/tests/conftest.py b/tests/conftest.py index 099c411..91ecf90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,14 @@ import uuid from pathlib import Path import pytest +import qdrant_client from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker +from testcontainers.qdrant import QdrantContainer from memory.common import settings from tests.providers.email_provider import MockEmailProvider +from memory.workers.qdrant import initialize_collections def get_test_db_name() -> str: @@ -197,3 +200,13 @@ def email_provider(): def mock_file_storage(tmp_path: Path): with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path): yield + + +@pytest.fixture +def qdrant(): + with QdrantContainer() as qdrant: + client = qdrant.get_client() + with patch.object(qdrant_client, "QdrantClient", return_value=client): + 
initialize_collections(client) + yield client + diff --git a/tests/memory/workers/tasks/test_email_tasks.py b/tests/memory/workers/tasks/test_email_tasks.py index 4503574..acd5abd 100644 --- a/tests/memory/workers/tasks/test_email_tasks.py +++ b/tests/memory/workers/tasks/test_email_tasks.py @@ -56,7 +56,7 @@ def test_email_account(db_session): return account -def test_process_simple_email(db_session, test_email_account): +def test_process_simple_email(db_session, test_email_account, qdrant): """Test processing a simple email message.""" source_id = process_message( account_id=test_email_account.id, @@ -85,7 +85,7 @@ def test_process_simple_email(db_session, test_email_account): assert mail_message.folder == "INBOX" -def test_process_email_with_attachment(db_session, test_email_account): +def test_process_email_with_attachment(db_session, test_email_account, qdrant): """Test processing a message with an attachment.""" source_id = process_message( account_id=test_email_account.id, @@ -116,7 +116,7 @@ def test_process_email_with_attachment(db_session, test_email_account): assert attachments[0].content is not None or attachments[0].file_path is not None -def test_process_empty_message(db_session, test_email_account): +def test_process_empty_message(db_session, test_email_account, qdrant): """Test processing an empty/invalid message.""" source_id = process_message( account_id=test_email_account.id, @@ -128,7 +128,7 @@ def test_process_empty_message(db_session, test_email_account): assert source_id is None -def test_process_duplicate_message(db_session, test_email_account): +def test_process_duplicate_message(db_session, test_email_account, qdrant): """Test that duplicate messages are detected and not stored again.""" # First call should succeed and create records source_id_1 = process_message( diff --git a/tests/memory/workers/test_email.py b/tests/memory/workers/test_email.py index 1ed4520..affefac 100644 --- a/tests/memory/workers/test_email.py +++ 
@pytest.fixture
def mock_uuid4():
    """Patch ``uuid.uuid4`` so it returns predictable, sequential UUID strings.

    The n-th call returns ``00000000-0000-0000-0000-`` followed by ``n``
    zero-padded to 12 digits, so the value keeps the standard UUID length
    even after more than nine calls.
    """
    calls = 0

    def fake_uuid4():
        nonlocal calls
        calls += 1
        # Zero-pad the counter instead of appending it to a fixed prefix,
        # which produced over-long strings from the 10th call on.
        return f"00000000-0000-0000-0000-{calls:012d}"

    with patch("uuid.uuid4", side_effect=fake_uuid4):
        yield
sender="sender@example.com", + recipients=["recipient@example.com"], + body_raw="This is a test email with attachments", + folder="INBOX", + ) + db_session.add(mail_message) + db_session.flush() + + # Add two attachments - one with content and one with file_path + attachment1 = EmailAttachment( + id=1, + mail_message_id=mail_message.id, + filename="inline.txt", + content_type="text/plain", + size=100, + content=base64.b64encode(b"This is inline content"), + file_path=None, + ) + + attachment2 = EmailAttachment( + id=2, + mail_message_id=mail_message.id, + filename="stored.txt", + content_type="text/plain", + size=200, + content=None, + file_path="/path/to/stored.txt", + ) + + db_session.add_all([attachment1, attachment2]) + db_session.flush() + + # Mock embedding functions but use real qdrant + with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \ + patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536): + + # Call the function + vector_ids = vectorize_email(mail_message) + + # Verify results + assert len(vector_ids) == 3 + assert vector_ids[0] == "00000000-0000-0000-0000-000000000001" + assert vector_ids[1] == "00000000-0000-0000-0000-000000000002" + assert vector_ids[2] == "00000000-0000-0000-0000-000000000003" diff --git a/tests/memory/workers/test_qdrant.py b/tests/memory/workers/test_qdrant.py new file mode 100644 index 0000000..a4f2bff --- /dev/null +++ b/tests/memory/workers/test_qdrant.py @@ -0,0 +1,82 @@ +import pytest +from unittest.mock import MagicMock, patch + +import qdrant_client +from qdrant_client.http import models as qdrant_models +from qdrant_client.http.exceptions import UnexpectedResponse + +from memory.workers.qdrant import ( + DEFAULT_COLLECTIONS, + ensure_collection_exists, + initialize_collections, + upsert_vectors, + search_vectors, + delete_vectors, +) + + +@pytest.fixture +def mock_qdrant_client(): + with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client: + 
yield mock_client + + +def test_ensure_collection_exists_existing(mock_qdrant_client): + mock_qdrant_client.get_collection.return_value = {} + assert not ensure_collection_exists(mock_qdrant_client, "test_collection", 128) + + mock_qdrant_client.get_collection.assert_called_once_with("test_collection") + mock_qdrant_client.create_collection.assert_not_called() + + +def test_ensure_collection_exists_new(mock_qdrant_client): + mock_qdrant_client.get_collection.side_effect = UnexpectedResponse( + status_code=404, reason_phrase='asd', content=b'asd', headers=None + ) + + assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128) + + mock_qdrant_client.get_collection.assert_called_once_with("test_collection") + mock_qdrant_client.create_collection.assert_called_once() + mock_qdrant_client.create_payload_index.assert_called_once() + + +def test_initialize_collections(mock_qdrant_client): + initialize_collections(mock_qdrant_client) + + assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS) + + +def test_upsert_vectors(mock_qdrant_client): + ids = ["1", "2"] + vectors = [[0.1, 0.2], [0.3, 0.4]] + payloads = [{"tag": "test1"}, {"tag": "test2"}] + + upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads) + + mock_qdrant_client.upsert.assert_called_once() + args, kwargs = mock_qdrant_client.upsert.call_args + assert kwargs["collection_name"] == "test_collection" + assert len(kwargs["points"]) == 2 + + # Check points were created correctly + points = kwargs["points"] + assert points[0].id == "1" + assert points[0].vector == [0.1, 0.2] + assert points[0].payload == {"tag": "test1"} + assert points[1].id == "2" + assert points[1].vector == [0.3, 0.4] + assert points[1].payload == {"tag": "test2"} + + +def test_delete_vectors(mock_qdrant_client): + ids = ["1", "2"] + + delete_vectors(mock_qdrant_client, "test_collection", ids) + + mock_qdrant_client.delete.assert_called_once() + args, kwargs = 
mock_qdrant_client.delete.call_args + + assert kwargs["collection_name"] == "test_collection" + assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList) + assert kwargs["points_selector"].points == ids \ No newline at end of file