add qdrant

This commit is contained in:
Daniel O'Connell 2025-04-27 22:24:30 +02:00
parent 889df318a1
commit a104a3211b
11 changed files with 532 additions and 10 deletions

View File

@ -3,3 +3,4 @@ pytest-cov==4.1.0
black==23.12.1
mypy==1.8.0
isort==5.13.2
testcontainers[qdrant]==4.10.0

View File

@ -29,6 +29,8 @@ class SourceItem(Base):
byte_length = Column(Integer)
mime_type = Column(Text)
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
# Add table-level constraint and indexes
__table_args__ = (
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
@ -53,11 +55,24 @@ class MailMessage(Base):
tsv = Column(TSVECTOR)
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
source = relationship("SourceItem", back_populates="mail_message")
@property
def attachments_path(self) -> Path:
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
def as_payload(self) -> dict:
return {
"source_id": self.source_id,
"message_id": self.message_id,
"subject": self.subject,
"sender": self.sender,
"recipients": self.recipients,
"folder": self.folder,
"tags": self.source.tags,
"date": self.sent_at and self.sent_at.isoformat() or None,
}
# Add indexes
__table_args__ = (
Index('mail_sent_idx', 'sent_at'),
@ -80,6 +95,17 @@ class EmailAttachment(Base):
mail_message = relationship("MailMessage", back_populates="attachments")
def as_payload(self) -> dict:
return {
"filename": self.filename,
"content_type": self.content_type,
"size": self.size,
"created_at": self.created_at and self.created_at.isoformat() or None,
"mail_message_id": self.mail_message_id,
"source_id": self.mail_message.source_id,
"tags": self.mail_message.source.tags,
}
# Add indexes
__table_args__ = (
Index('email_attachment_message_idx', 'mail_message_id'),

View File

@ -0,0 +1,12 @@
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
"""
Embed a text using OpenAI's API.
"""
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
"""
Embed a file using OpenAI's API.
"""
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector

View File

@ -4,6 +4,9 @@ from dotenv import load_dotenv
load_dotenv()
def boolean_env(key: str, default: bool = False) -> bool:
return os.getenv(key, "0").lower() in ("1", "true", "yes")
DB_USER = os.getenv("DB_USER", "kb")
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
@ -22,3 +25,11 @@ FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
# Maximum attachment size to store directly in the database (10MB)
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
# Qdrant settings
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))

View File

@ -13,7 +13,8 @@ import pathlib
from sqlalchemy.orm import Session
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import settings
from memory.common import settings, embedding
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
logger = logging.getLogger(__name__)
@ -456,3 +457,32 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
conn.logout()
except Exception as e:
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
def vectorize_email(email: MailMessage) -> list[float]:
qdrant_client = get_qdrant_client()
vector_ids = [str(uuid.uuid4())]
vectors = [embedding.embed_text(email.body_raw)]
payloads = [email.as_payload()]
for attachment in email.attachments:
vector_ids.append(str(uuid.uuid4()))
payloads.append(attachment.as_payload())
if attachment.file_path:
vector = embedding.embed_file(attachment.file_path)
else:
vector = embedding.embed_text(attachment.content)
vectors.append(vector)
upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=vector_ids,
vectors=vectors,
payloads=payloads,
)
logger.info(f"Stored embedding for message {email.message_id}")
return vector_ids

View File

@ -0,0 +1,231 @@
import logging
from typing import Any, Literal, TypedDict, cast
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
logger = logging.getLogger(__name__)
# Type of distance metric
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
on_disk: bool
shards: int
DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1536, "distance": "Cosine"},
# "chat": {"dimension": 1536, "distance": "Cosine"},
# "git": {"dimension": 1536, "distance": "Cosine"},
# "photo": {"dimension": 512, "distance": "Cosine"},
# "book": {"dimension": 1536, "distance": "Cosine"},
# "blog": {"dimension": 1536, "distance": "Cosine"},
# "doc": {"dimension": 1536, "distance": "Cosine"},
}
def get_qdrant_client() -> qdrant_client.QdrantClient:
"""Create and return a Qdrant client using environment configuration."""
logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
return qdrant_client.QdrantClient(
host=settings.QDRANT_HOST,
port=settings.QDRANT_PORT,
grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None,
prefer_grpc=settings.QDRANT_PREFER_GRPC,
api_key=settings.QDRANT_API_KEY,
timeout=settings.QDRANT_TIMEOUT,
)
def ensure_collection_exists(
client: qdrant_client.QdrantClient,
collection_name: str,
dimension: int,
distance: DistanceType = "Cosine",
on_disk: bool = True,
shards: int = 1,
) -> bool:
"""
Ensure a collection exists with the specified parameters.
Args:
client: Qdrant client
collection_name: Name of the collection
dimension: Vector dimension
distance: Distance metric (Cosine, Dot, Euclidean)
on_disk: Whether to store vectors on disk
shards: Number of shards for the collection
Returns:
True if the collection was created, False if it already existed
"""
try:
client.get_collection(collection_name)
logger.debug(f"Collection {collection_name} already exists")
return False
except (UnexpectedResponse, ValueError):
logger.info(f"Creating collection {collection_name} with dimension {dimension}")
client.create_collection(
collection_name=collection_name,
vectors_config=qdrant_models.VectorParams(
size=dimension,
distance=cast(qdrant_models.Distance, distance),
),
on_disk_payload=on_disk,
shard_number=shards,
)
# Create common payload indexes
client.create_payload_index(
collection_name=collection_name,
field_name="tags",
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
)
return True
def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
"""
Initialize all required collections in Qdrant.
Args:
client: Qdrant client
collections: Dictionary mapping collection names to their parameters.
If None, defaults to the DEFAULT_COLLECTIONS.
"""
if collections is None:
collections = DEFAULT_COLLECTIONS
for name, params in collections.items():
ensure_collection_exists(
client,
collection_name=name,
dimension=params["dimension"],
distance=params.get("distance", "Cosine"),
on_disk=params.get("on_disk", True),
shards=params.get("shards", 1),
)
def setup_qdrant() -> qdrant_client.QdrantClient:
"""Get a Qdrant client and initialize collections.
Returns:
Configured Qdrant client
"""
client = get_qdrant_client()
initialize_collections(client)
return client
def upsert_vectors(
client: qdrant_client.QdrantClient,
collection_name: str,
ids: list[str],
vectors: list[Vector],
payloads: list[dict[str, Any]] = None,
) -> None:
"""Upsert vectors into a collection.
Args:
client: Qdrant client
collection_name: Name of the collection
ids: List of vector IDs (as strings)
vectors: List of vectors
payloads: List of payloads, one per vector
"""
if payloads is None:
payloads = [{} for _ in ids]
points = [
qdrant_models.PointStruct(
id=id_str,
vector=vector,
payload=payload,
)
for id_str, vector, payload in zip(ids, vectors, payloads)
]
client.upsert(
collection_name=collection_name,
points=points,
)
logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
def search_vectors(
client: qdrant_client.QdrantClient,
collection_name: str,
query_vector: Vector,
filter_params: dict = None,
limit: int = 10,
) -> list[qdrant_models.ScoredPoint]:
"""Search for similar vectors in a collection.
Args:
client: Qdrant client
collection_name: Name of the collection
query_vector: Query vector
filter_params: Filter parameters to apply (e.g., {"tags": {"value": "work"}})
limit: Maximum number of results to return
Returns:
List of scored points
"""
filter_obj = None
if filter_params:
filter_obj = qdrant_models.Filter(**filter_params)
return client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=filter_obj,
limit=limit,
)
def delete_vectors(
client: qdrant_client.QdrantClient,
collection_name: str,
ids: list[str],
) -> None:
"""
Delete vectors from a collection.
Args:
client: Qdrant client
collection_name: Name of the collection
ids: List of vector IDs to delete
"""
client.delete(
collection_name=collection_name,
points_selector=qdrant_models.PointIdsList(
points=ids,
),
)
logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
"""
Get information about a collection.
Args:
client: Qdrant client
collection_name: Name of the collection
Returns:
Dictionary with collection information
"""
info = client.get_collection(collection_name)
return info.model_dump()

View File

@ -12,6 +12,7 @@ from memory.workers.email import (
imap_connection,
parse_email_message,
process_folder,
vectorize_email,
)
@ -63,11 +64,15 @@ def process_message(
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
create_mail_message(db, source_item.id, parsed_email, folder)
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
source_item.vector_ids = vectorize_email(mail_message)
db.commit()
# TODO: Queue for embedding once that's implemented
logger.info(f"Stored embedding for message {parsed_email['message_id']}")
logger.info("Vector IDs:")
for vector_id in source_item.vector_ids:
logger.info(f" - {vector_id}")
return source_item.id

View File

@ -6,11 +6,14 @@ import uuid
from pathlib import Path
import pytest
import qdrant_client
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from testcontainers.qdrant import QdrantContainer
from memory.common import settings
from tests.providers.email_provider import MockEmailProvider
from memory.workers.qdrant import initialize_collections
def get_test_db_name() -> str:
@ -197,3 +200,13 @@ def email_provider():
def mock_file_storage(tmp_path: Path):
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
yield
@pytest.fixture
def qdrant():
with QdrantContainer() as qdrant:
client = qdrant.get_client()
with patch.object(qdrant_client, "QdrantClient", return_value=client):
initialize_collections(client)
yield client

View File

@ -56,7 +56,7 @@ def test_email_account(db_session):
return account
def test_process_simple_email(db_session, test_email_account):
def test_process_simple_email(db_session, test_email_account, qdrant):
"""Test processing a simple email message."""
source_id = process_message(
account_id=test_email_account.id,
@ -85,7 +85,7 @@ def test_process_simple_email(db_session, test_email_account):
assert mail_message.folder == "INBOX"
def test_process_email_with_attachment(db_session, test_email_account):
def test_process_email_with_attachment(db_session, test_email_account, qdrant):
"""Test processing a message with an attachment."""
source_id = process_message(
account_id=test_email_account.id,
@ -116,7 +116,7 @@ def test_process_email_with_attachment(db_session, test_email_account):
assert attachments[0].content is not None or attachments[0].file_path is not None
def test_process_empty_message(db_session, test_email_account):
def test_process_empty_message(db_session, test_email_account, qdrant):
"""Test processing an empty/invalid message."""
source_id = process_message(
account_id=test_email_account.id,
@ -128,7 +128,7 @@ def test_process_empty_message(db_session, test_email_account):
assert source_id is None
def test_process_duplicate_message(db_session, test_email_account):
def test_process_duplicate_message(db_session, test_email_account, qdrant):
"""Test that duplicate messages are detected and not stored again."""
# First call should succeed and create records
source_id_1 = process_message(

View File

@ -27,9 +27,21 @@ from memory.workers.email import (
process_folder,
process_attachment,
process_attachments,
vectorize_email,
)
@pytest.fixture
def mock_uuid4():
i = 0
def uuid4():
nonlocal i
i += 1
return f"00000000-0000-0000-0000-00000000000{i}"
with patch("uuid.uuid4", side_effect=uuid4):
yield
# Use a simple counter to generate unique message IDs without calling make_msgid
_msg_id_counter = 0
@ -655,3 +667,102 @@ def test_process_folder_error(email_provider):
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
)
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
source_item = SourceItem(
id=1,
modality="mail",
sha256=b"test_hash" + bytes(24),
tags=["test"],
byte_length=100,
mime_type="message/rfc822",
embed_status="RAW",
)
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=1,
source_id=1,
message_id="<test-vector@example.com>",
subject="Test Vectorization",
sender="sender@example.com",
recipients=["recipient@example.com"],
body_raw="This is a test email for vectorization",
folder="INBOX",
)
db_session.add(mail_message)
db_session.flush()
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
source_item = SourceItem(
id=2,
modality="mail",
sha256=b"test_hash2" + bytes(24),
tags=["test"],
byte_length=200,
mime_type="message/rfc822",
embed_status="RAW",
)
db_session.add(source_item)
db_session.flush()
# Create mail message
mail_message = MailMessage(
id=2,
source_id=2,
message_id="<test-vector-attach@example.com>",
subject="Test Vectorization with Attachments",
sender="sender@example.com",
recipients=["recipient@example.com"],
body_raw="This is a test email with attachments",
folder="INBOX",
)
db_session.add(mail_message)
db_session.flush()
# Add two attachments - one with content and one with file_path
attachment1 = EmailAttachment(
id=1,
mail_message_id=mail_message.id,
filename="inline.txt",
content_type="text/plain",
size=100,
content=base64.b64encode(b"This is inline content"),
file_path=None,
)
attachment2 = EmailAttachment(
id=2,
mail_message_id=mail_message.id,
filename="stored.txt",
content_type="text/plain",
size=200,
content=None,
file_path="/path/to/stored.txt",
)
db_session.add_all([attachment1, attachment2])
db_session.flush()
# Mock embedding functions but use real qdrant
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \
patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
# Call the function
vector_ids = vectorize_email(mail_message)
# Verify results
assert len(vector_ids) == 3
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
assert vector_ids[1] == "00000000-0000-0000-0000-000000000002"
assert vector_ids[2] == "00000000-0000-0000-0000-000000000003"

View File

@ -0,0 +1,82 @@
import pytest
from unittest.mock import MagicMock, patch
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.workers.qdrant import (
DEFAULT_COLLECTIONS,
ensure_collection_exists,
initialize_collections,
upsert_vectors,
search_vectors,
delete_vectors,
)
@pytest.fixture
def mock_qdrant_client():
with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client:
yield mock_client
def test_ensure_collection_exists_existing(mock_qdrant_client):
mock_qdrant_client.get_collection.return_value = {}
assert not ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
mock_qdrant_client.create_collection.assert_not_called()
def test_ensure_collection_exists_new(mock_qdrant_client):
mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
status_code=404, reason_phrase='asd', content=b'asd', headers=None
)
assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
mock_qdrant_client.create_collection.assert_called_once()
mock_qdrant_client.create_payload_index.assert_called_once()
def test_initialize_collections(mock_qdrant_client):
initialize_collections(mock_qdrant_client)
assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS)
def test_upsert_vectors(mock_qdrant_client):
ids = ["1", "2"]
vectors = [[0.1, 0.2], [0.3, 0.4]]
payloads = [{"tag": "test1"}, {"tag": "test2"}]
upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads)
mock_qdrant_client.upsert.assert_called_once()
args, kwargs = mock_qdrant_client.upsert.call_args
assert kwargs["collection_name"] == "test_collection"
assert len(kwargs["points"]) == 2
# Check points were created correctly
points = kwargs["points"]
assert points[0].id == "1"
assert points[0].vector == [0.1, 0.2]
assert points[0].payload == {"tag": "test1"}
assert points[1].id == "2"
assert points[1].vector == [0.3, 0.4]
assert points[1].payload == {"tag": "test2"}
def test_delete_vectors(mock_qdrant_client):
ids = ["1", "2"]
delete_vectors(mock_qdrant_client, "test_collection", ids)
mock_qdrant_client.delete.assert_called_once()
args, kwargs = mock_qdrant_client.delete.call_args
assert kwargs["collection_name"] == "test_collection"
assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList)
assert kwargs["points_selector"].points == ids