add qdrant

This commit is contained in:
Daniel O'Connell 2025-04-27 22:24:30 +02:00
parent 889df318a1
commit a104a3211b
11 changed files with 532 additions and 10 deletions

View File

@ -3,3 +3,4 @@ pytest-cov==4.1.0
black==23.12.1 black==23.12.1
mypy==1.8.0 mypy==1.8.0
isort==5.13.2 isort==5.13.2
testcontainers[qdrant]==4.10.0

View File

@ -29,6 +29,8 @@ class SourceItem(Base):
byte_length = Column(Integer) byte_length = Column(Integer)
mime_type = Column(Text) mime_type = Column(Text)
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
# Add table-level constraint and indexes # Add table-level constraint and indexes
__table_args__ = ( __table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"), CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
@ -53,11 +55,24 @@ class MailMessage(Base):
tsv = Column(TSVECTOR) tsv = Column(TSVECTOR)
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan") attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
source = relationship("SourceItem", back_populates="mail_message")
@property @property
def attachments_path(self) -> Path: def attachments_path(self) -> Path:
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX') return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
def as_payload(self) -> dict:
    """Build the metadata dict that accompanies this message's vector.

    Tags are read from the related ``source`` item. ``date`` is the
    ISO-8601 form of ``sent_at``, or ``None`` when ``sent_at`` is unset.
    """
    return dict(
        source_id=self.source_id,
        message_id=self.message_id,
        subject=self.subject,
        sender=self.sender,
        recipients=self.recipients,
        folder=self.folder,
        tags=self.source.tags,
        date=self.sent_at.isoformat() if self.sent_at else None,
    )
# Add indexes # Add indexes
__table_args__ = ( __table_args__ = (
Index('mail_sent_idx', 'sent_at'), Index('mail_sent_idx', 'sent_at'),
@ -80,6 +95,17 @@ class EmailAttachment(Base):
mail_message = relationship("MailMessage", back_populates="attachments") mail_message = relationship("MailMessage", back_populates="attachments")
def as_payload(self) -> dict:
    """Build the metadata dict that accompanies this attachment's vector.

    ``source_id`` and ``tags`` are taken from the parent mail message and
    its source item. ``created_at`` is rendered as an ISO-8601 string, or
    ``None`` when it is unset.
    """
    created = self.created_at.isoformat() if self.created_at else None
    parent = self.mail_message
    return dict(
        filename=self.filename,
        content_type=self.content_type,
        size=self.size,
        created_at=created,
        mail_message_id=self.mail_message_id,
        source_id=parent.source_id,
        tags=parent.source.tags,
    )
# Add indexes # Add indexes
__table_args__ = ( __table_args__ = (
Index('email_attachment_message_idx', 'mail_message_id'), Index('email_attachment_message_idx', 'mail_message_id'),

View File

@ -0,0 +1,12 @@
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
    """
    Return a placeholder embedding for a piece of text.

    No embedding service is called yet: ``text`` and ``model`` are ignored
    and the result is a zero vector of length ``n_dimensions``.
    """
    return [0.0 for _ in range(n_dimensions)]
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
    """
    Return a placeholder embedding for a file.

    The file is not read and no embedding service is called yet:
    ``file_path`` and ``model`` are ignored and the result is a zero vector
    of length ``n_dimensions``.
    """
    return [0.0 for _ in range(n_dimensions)]

View File

@ -4,6 +4,9 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
def boolean_env(key: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean flag.

    Args:
        key: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` when its
        value is "1", "true" or "yes" (case-insensitive) and ``False`` for
        anything else.
    """
    raw = os.getenv(key)
    if raw is None:
        # Previously a hard-coded "0" fallback made `default` a no-op.
        return default
    return raw.lower() in ("1", "true", "yes")
DB_USER = os.getenv("DB_USER", "kb") DB_USER = os.getenv("DB_USER", "kb")
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb") DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
@ -22,3 +25,11 @@ FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
# Maximum attachment size to store directly in the database (1MB) # Maximum attachment size to store directly in the database (1MB)
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024)) MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
# Qdrant settings
# Each value is read from the environment variable of the same name, with the
# fallback shown here when that variable is unset.
# Host and HTTP port of the Qdrant server.
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# gRPC port of the Qdrant server.
QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
# Optional API key; None when the variable is unset.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
# Whether the client should prefer gRPC over HTTP.
QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
# Client timeout. NOTE(review): presumably seconds — confirm against the Qdrant client.
QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

View File

@ -13,7 +13,8 @@ import pathlib
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import settings from memory.common import settings, embedding
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -456,3 +457,32 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
conn.logout() conn.logout()
except Exception as e: except Exception as e:
logger.error(f"Error logging out from {account.imap_server}: {str(e)}") logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
def vectorize_email(email: MailMessage) -> list[str]:
    """Embed an email and its attachments and store the vectors in Qdrant.

    One vector is produced for the message body and one per attachment. Each
    gets a freshly generated UUID, and all of them are upserted into the
    "mail" collection in a single call.

    Args:
        email: The message to embed; its ``attachments`` are embedded as well.

    Returns:
        The generated vector IDs as strings: the body's ID first, followed by
        one ID per attachment in iteration order.
    """
    qdrant_client = get_qdrant_client()

    vector_ids = [str(uuid.uuid4())]
    vectors = [embedding.embed_text(email.body_raw)]
    payloads = [email.as_payload()]

    for attachment in email.attachments:
        vector_ids.append(str(uuid.uuid4()))
        payloads.append(attachment.as_payload())

        # Attachments with a file path are embedded from that file; the rest
        # are embedded from their inline content.
        # NOTE(review): inline `content` may be bytes rather than text — confirm
        # what embed_text should receive here.
        if attachment.file_path:
            vector = embedding.embed_file(attachment.file_path)
        else:
            vector = embedding.embed_text(attachment.content)
        vectors.append(vector)

    upsert_vectors(
        client=qdrant_client,
        collection_name="mail",
        ids=vector_ids,
        vectors=vectors,
        payloads=payloads,
    )

    logger.info(f"Stored embedding for message {email.message_id}")
    return vector_ids

View File

@ -0,0 +1,231 @@
import logging
from typing import Any, Literal, TypedDict, cast
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
logger = logging.getLogger(__name__)
# Type of distance metric.
# "Euclid" is the spelling Qdrant's own Distance enum uses; "Euclidean" is
# retained so existing annotations keep type-checking.
DistanceType = Literal["Cosine", "Dot", "Euclid", "Euclidean"]
Vector = list[float]


class _CollectionRequired(TypedDict):
    """Keys every collection definition must provide."""

    dimension: int


class Collection(_CollectionRequired, total=False):
    """Parameters describing a Qdrant collection.

    Only ``dimension`` is required. ``distance``, ``on_disk`` and ``shards``
    may be omitted, which matches entries that specify only some of them.
    """

    distance: DistanceType
    on_disk: bool
    shards: int
# Collections used by initialize_collections() when no explicit mapping is
# passed. Only "mail" is active; the commented-out entries are not created.
# NOTE(review): the commented-out entries look like planned collections —
# confirm dimensions before enabling them.
DEFAULT_COLLECTIONS: dict[str, Collection] = {
    "mail": {"dimension": 1536, "distance": "Cosine"},
    # "chat": {"dimension": 1536, "distance": "Cosine"},
    # "git": {"dimension": 1536, "distance": "Cosine"},
    # "photo": {"dimension": 512, "distance": "Cosine"},
    # "book": {"dimension": 1536, "distance": "Cosine"},
    # "blog": {"dimension": 1536, "distance": "Cosine"},
    # "doc": {"dimension": 1536, "distance": "Cosine"},
}
def get_qdrant_client() -> qdrant_client.QdrantClient:
    """Build a Qdrant client from the values in ``settings``.

    The gRPC port is only handed to the client when gRPC is preferred;
    otherwise ``None`` is passed for it.
    """
    use_grpc = settings.QDRANT_PREFER_GRPC
    grpc_port = settings.QDRANT_GRPC_PORT if use_grpc else None

    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")

    client = qdrant_client.QdrantClient(
        host=settings.QDRANT_HOST,
        port=settings.QDRANT_PORT,
        grpc_port=grpc_port,
        prefer_grpc=use_grpc,
        api_key=settings.QDRANT_API_KEY,
        timeout=settings.QDRANT_TIMEOUT,
    )
    return client
def ensure_collection_exists(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    dimension: int,
    distance: DistanceType = "Cosine",
    on_disk: bool = True,
    shards: int = 1,
) -> bool:
    """
    Ensure a collection exists with the specified parameters.

    An existing collection is left untouched; its parameters are not compared
    with the ones given here. A newly created collection also gets a keyword
    payload index on the "tags" field.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        dimension: Vector dimension
        distance: Distance metric name. "Euclidean" is accepted as an alias
            for Qdrant's "Euclid".
        on_disk: Whether to store payloads on disk (passed to Qdrant as
            ``on_disk_payload``)
        shards: Number of shards for the collection

    Returns:
        True if the collection was created, False if it already existed

    Raises:
        ValueError: If ``distance`` is not a metric name Qdrant knows.
    """
    try:
        client.get_collection(collection_name)
    except (UnexpectedResponse, ValueError):
        # A failed lookup is treated as "the collection does not exist yet".
        logger.info(f"Creating collection {collection_name} with dimension {dimension}")
    else:
        logger.debug(f"Collection {collection_name} already exists")
        return False

    # Qdrant's Distance enum spells the Euclidean metric "Euclid", so translate
    # the alias and build a real enum member rather than casting a plain string.
    metric_name = "Euclid" if distance == "Euclidean" else distance
    client.create_collection(
        collection_name=collection_name,
        vectors_config=qdrant_models.VectorParams(
            size=dimension,
            distance=qdrant_models.Distance(metric_name),
        ),
        on_disk_payload=on_disk,
        shard_number=shards,
    )

    # Create common payload indexes
    client.create_payload_index(
        collection_name=collection_name,
        field_name="tags",
        field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
    )
    return True
def initialize_collections(
    client: qdrant_client.QdrantClient,
    collections: "dict[str, Collection] | None" = None,
) -> None:
    """
    Initialize all required collections in Qdrant.

    Each entry is passed to ``ensure_collection_exists``, so collections that
    already exist are left as they are.

    Args:
        client: Qdrant client
        collections: Dictionary mapping collection names to their parameters.
            If None, defaults to the DEFAULT_COLLECTIONS. Only "dimension" is
            mandatory per entry; "distance", "on_disk" and "shards" fall back
            to "Cosine", True and 1.
    """
    # The annotation is quoted so the explicit `| None` is never evaluated at
    # runtime, which keeps it valid on interpreters without PEP 604 support.
    if collections is None:
        collections = DEFAULT_COLLECTIONS

    for name, params in collections.items():
        ensure_collection_exists(
            client,
            collection_name=name,
            dimension=params["dimension"],
            distance=params.get("distance", "Cosine"),
            on_disk=params.get("on_disk", True),
            shards=params.get("shards", 1),
        )
def setup_qdrant() -> qdrant_client.QdrantClient:
    """Create a Qdrant client and make sure the default collections exist.

    Returns:
        The client, after the default collections have been initialized
    """
    qdrant = get_qdrant_client()
    initialize_collections(qdrant)
    return qdrant
def upsert_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
    vectors: list[Vector],
    payloads: "list[dict[str, Any]] | None" = None,
) -> None:
    """Upsert vectors into a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: List of vector IDs (as strings)
        vectors: List of vectors, one per ID
        payloads: List of payloads, one per vector. If None, every point gets
            an empty payload.

    Raises:
        ValueError: If ``ids``, ``vectors`` and ``payloads`` do not all have
            the same length.
    """
    if payloads is None:
        payloads = [{} for _ in ids]

    # zip() stops at the shortest input, so a length mismatch would silently
    # drop points; fail loudly instead.
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError(
            "ids, vectors and payloads must have the same length "
            f"(got {len(ids)}, {len(vectors)} and {len(payloads)})"
        )

    points = [
        qdrant_models.PointStruct(
            id=id_str,
            vector=vector,
            payload=payload,
        )
        for id_str, vector, payload in zip(ids, vectors, payloads)
    ]

    client.upsert(
        collection_name=collection_name,
        points=points,
    )
    logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
def search_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    query_vector: Vector,
    filter_params: "dict | None" = None,
    limit: int = 10,
) -> list[qdrant_models.ScoredPoint]:
    """Search for similar vectors in a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        query_vector: Query vector
        filter_params: Keyword arguments for ``qdrant_models.Filter``, e.g.
            {"must": [{"key": "tags", "match": {"value": "work"}}]}.
            None or an empty dict means no filtering.
        limit: Maximum number of results to return

    Returns:
        List of scored points
    """
    filter_obj = None
    if filter_params:
        # The dict is expanded straight into the Filter model, so its keys must
        # be Filter fields rather than payload field names.
        filter_obj = qdrant_models.Filter(**filter_params)

    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=filter_obj,
        limit=limit,
    )
def delete_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
) -> None:
    """
    Remove the points with the given IDs from a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: IDs of the vectors to remove
    """
    selector = qdrant_models.PointIdsList(points=ids)
    client.delete(collection_name=collection_name, points_selector=selector)

    removed = len(ids)
    logger.debug(f"Deleted {removed} vectors from {collection_name}")
def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
    """
    Fetch a collection's description and return it as a plain dictionary.

    Args:
        client: Qdrant client
        collection_name: Name of the collection

    Returns:
        The ``model_dump()`` of the object returned by ``client.get_collection``
    """
    return client.get_collection(collection_name).model_dump()

View File

@ -12,6 +12,7 @@ from memory.workers.email import (
imap_connection, imap_connection,
parse_email_message, parse_email_message,
process_folder, process_folder,
vectorize_email,
) )
@ -63,11 +64,15 @@ def process_message(
source_item = create_source_item(db, message_hash, account.tags, len(raw_email)) source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
create_mail_message(db, source_item.id, parsed_email, folder) mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
source_item.vector_ids = vectorize_email(mail_message)
db.commit() db.commit()
# TODO: Queue for embedding once that's implemented logger.info(f"Stored embedding for message {parsed_email['message_id']}")
logger.info("Vector IDs:")
for vector_id in source_item.vector_ids:
logger.info(f" - {vector_id}")
return source_item.id return source_item.id

View File

@ -6,11 +6,14 @@ import uuid
from pathlib import Path from pathlib import Path
import pytest import pytest
import qdrant_client
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from testcontainers.qdrant import QdrantContainer
from memory.common import settings from memory.common import settings
from tests.providers.email_provider import MockEmailProvider from tests.providers.email_provider import MockEmailProvider
from memory.workers.qdrant import initialize_collections
def get_test_db_name() -> str: def get_test_db_name() -> str:
@ -197,3 +200,13 @@ def email_provider():
def mock_file_storage(tmp_path: Path): def mock_file_storage(tmp_path: Path):
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path): with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
yield yield
@pytest.fixture
def qdrant():
    """Start a throwaway Qdrant container and yield a client connected to it.

    While the fixture is active ``qdrant_client.QdrantClient`` is patched to
    return that client, and the default collections are created up front.
    """
    with QdrantContainer() as container:
        test_client = container.get_client()
        client_patch = patch.object(qdrant_client, "QdrantClient", return_value=test_client)
        with client_patch:
            initialize_collections(test_client)
            yield test_client

View File

@ -56,7 +56,7 @@ def test_email_account(db_session):
return account return account
def test_process_simple_email(db_session, test_email_account): def test_process_simple_email(db_session, test_email_account, qdrant):
"""Test processing a simple email message.""" """Test processing a simple email message."""
source_id = process_message( source_id = process_message(
account_id=test_email_account.id, account_id=test_email_account.id,
@ -85,7 +85,7 @@ def test_process_simple_email(db_session, test_email_account):
assert mail_message.folder == "INBOX" assert mail_message.folder == "INBOX"
def test_process_email_with_attachment(db_session, test_email_account): def test_process_email_with_attachment(db_session, test_email_account, qdrant):
"""Test processing a message with an attachment.""" """Test processing a message with an attachment."""
source_id = process_message( source_id = process_message(
account_id=test_email_account.id, account_id=test_email_account.id,
@ -116,7 +116,7 @@ def test_process_email_with_attachment(db_session, test_email_account):
assert attachments[0].content is not None or attachments[0].file_path is not None assert attachments[0].content is not None or attachments[0].file_path is not None
def test_process_empty_message(db_session, test_email_account): def test_process_empty_message(db_session, test_email_account, qdrant):
"""Test processing an empty/invalid message.""" """Test processing an empty/invalid message."""
source_id = process_message( source_id = process_message(
account_id=test_email_account.id, account_id=test_email_account.id,
@ -128,7 +128,7 @@ def test_process_empty_message(db_session, test_email_account):
assert source_id is None assert source_id is None
def test_process_duplicate_message(db_session, test_email_account): def test_process_duplicate_message(db_session, test_email_account, qdrant):
"""Test that duplicate messages are detected and not stored again.""" """Test that duplicate messages are detected and not stored again."""
# First call should succeed and create records # First call should succeed and create records
source_id_1 = process_message( source_id_1 = process_message(

View File

@ -27,9 +27,21 @@ from memory.workers.email import (
process_folder, process_folder,
process_attachment, process_attachment,
process_attachments, process_attachments,
vectorize_email,
) )
@pytest.fixture
def mock_uuid4():
    """Patch ``uuid.uuid4`` to return predictable, sequential UUID strings.

    The n-th call returns the all-zero UUID with n zero-padded into the final
    12-digit group, so values stay well-formed beyond nine calls. For calls
    1-9 the output is unchanged.
    """
    calls = 0

    def uuid4():
        nonlocal calls
        calls += 1
        return f"00000000-0000-0000-0000-{calls:012d}"

    with patch("uuid.uuid4", side_effect=uuid4):
        yield
# Use a simple counter to generate unique message IDs without calling make_msgid # Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0 _msg_id_counter = 0
@ -655,3 +667,102 @@ def test_process_folder_error(email_provider):
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
) )
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0} assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
    """A message without attachments yields exactly one vector ID."""
    db_session.add(
        SourceItem(
            id=1,
            modality="mail",
            sha256=b"test_hash" + bytes(24),
            tags=["test"],
            byte_length=100,
            mime_type="message/rfc822",
            embed_status="RAW",
        )
    )
    db_session.flush()

    # The message refers to the source item flushed above.
    message = MailMessage(
        id=1,
        source_id=1,
        message_id="<test-vector@example.com>",
        subject="Test Vectorization",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email for vectorization",
        folder="INBOX",
    )
    db_session.add(message)
    db_session.flush()

    fake_embedding = [0.1] * 1536
    with patch("memory.common.embedding.embed_text", return_value=fake_embedding):
        vector_ids = vectorize_email(message)

    assert vector_ids == ["00000000-0000-0000-0000-000000000001"]
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
    """A message with two attachments yields three vector IDs, body first."""
    db_session.add(
        SourceItem(
            id=2,
            modality="mail",
            sha256=b"test_hash2" + bytes(24),
            tags=["test"],
            byte_length=200,
            mime_type="message/rfc822",
            embed_status="RAW",
        )
    )
    db_session.flush()

    message = MailMessage(
        id=2,
        source_id=2,
        message_id="<test-vector-attach@example.com>",
        subject="Test Vectorization with Attachments",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email with attachments",
        folder="INBOX",
    )
    db_session.add(message)
    db_session.flush()

    # One attachment carries inline content, the other only a file path.
    inline_attachment = EmailAttachment(
        id=1,
        mail_message_id=message.id,
        filename="inline.txt",
        content_type="text/plain",
        size=100,
        content=base64.b64encode(b"This is inline content"),
        file_path=None,
    )
    stored_attachment = EmailAttachment(
        id=2,
        mail_message_id=message.id,
        filename="stored.txt",
        content_type="text/plain",
        size=200,
        content=None,
        file_path="/path/to/stored.txt",
    )
    db_session.add_all([inline_attachment, stored_attachment])
    db_session.flush()

    # Only the embedding functions are mocked; the vectors go to the qdrant fixture.
    with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
        with patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
            vector_ids = vectorize_email(message)

    assert vector_ids == [
        "00000000-0000-0000-0000-000000000001",
        "00000000-0000-0000-0000-000000000002",
        "00000000-0000-0000-0000-000000000003",
    ]

View File

@ -0,0 +1,82 @@
import pytest
from unittest.mock import MagicMock, patch
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.workers.qdrant import (
DEFAULT_COLLECTIONS,
ensure_collection_exists,
initialize_collections,
upsert_vectors,
search_vectors,
delete_vectors,
)
@pytest.fixture
def mock_qdrant_client():
    """Patch ``qdrant_client.QdrantClient`` and yield the patched mock.

    The yielded object is the mock standing in for the class itself; the
    tests use it directly as a client.
    """
    patcher = patch.object(qdrant_client, "QdrantClient", return_value=MagicMock())
    with patcher as patched_class:
        yield patched_class
def test_ensure_collection_exists_existing(mock_qdrant_client):
    """An existing collection is reported as not created and left untouched."""
    mock_qdrant_client.get_collection.return_value = {}

    created = ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

    assert not created
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_not_called()
def test_ensure_collection_exists_new(mock_qdrant_client):
    """A failed lookup leads to the collection and its payload index being created."""
    not_found = UnexpectedResponse(
        status_code=404, reason_phrase='asd', content=b'asd', headers=None
    )
    mock_qdrant_client.get_collection.side_effect = not_found

    created = ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

    assert created
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_called_once()
    mock_qdrant_client.create_payload_index.assert_called_once()
def test_initialize_collections(mock_qdrant_client):
    """Every default collection is looked up exactly once."""
    initialize_collections(mock_qdrant_client)

    expected_lookups = len(DEFAULT_COLLECTIONS)
    assert mock_qdrant_client.get_collection.call_count == expected_lookups
def test_upsert_vectors(mock_qdrant_client):
    """IDs, vectors and payloads are combined into points in order."""
    ids = ["1", "2"]
    vectors = [[0.1, 0.2], [0.3, 0.4]]
    payloads = [{"tag": "test1"}, {"tag": "test2"}]

    upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads)

    mock_qdrant_client.upsert.assert_called_once()
    _, kwargs = mock_qdrant_client.upsert.call_args
    assert kwargs["collection_name"] == "test_collection"

    # Check points were created correctly
    points = kwargs["points"]
    assert len(points) == 2
    expected = zip(points, ids, vectors, payloads)
    for point, expected_id, expected_vector, expected_payload in expected:
        assert point.id == expected_id
        assert point.vector == expected_vector
        assert point.payload == expected_payload
def test_delete_vectors(mock_qdrant_client):
    """The IDs are sent to the client wrapped in a PointIdsList selector."""
    ids = ["1", "2"]

    delete_vectors(mock_qdrant_client, "test_collection", ids)

    mock_qdrant_client.delete.assert_called_once()
    _, kwargs = mock_qdrant_client.delete.call_args
    selector = kwargs["points_selector"]
    assert kwargs["collection_name"] == "test_collection"
    assert isinstance(selector, qdrant_models.PointIdsList)
    assert selector.points == ids