mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 21:34:42 +02:00
add qdrant
This commit is contained in:
parent
889df318a1
commit
a104a3211b
@ -3,3 +3,4 @@ pytest-cov==4.1.0
|
|||||||
black==23.12.1
|
black==23.12.1
|
||||||
mypy==1.8.0
|
mypy==1.8.0
|
||||||
isort==5.13.2
|
isort==5.13.2
|
||||||
|
testcontainers[qdrant]==4.10.0
|
@ -29,6 +29,8 @@ class SourceItem(Base):
|
|||||||
byte_length = Column(Integer)
|
byte_length = Column(Integer)
|
||||||
mime_type = Column(Text)
|
mime_type = Column(Text)
|
||||||
|
|
||||||
|
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
|
||||||
|
|
||||||
# Add table-level constraint and indexes
|
# Add table-level constraint and indexes
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||||
@ -53,11 +55,24 @@ class MailMessage(Base):
|
|||||||
tsv = Column(TSVECTOR)
|
tsv = Column(TSVECTOR)
|
||||||
|
|
||||||
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
|
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
|
||||||
|
source = relationship("SourceItem", back_populates="mail_message")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def attachments_path(self) -> Path:
|
def attachments_path(self) -> Path:
|
||||||
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
|
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
|
||||||
|
|
||||||
|
def as_payload(self) -> dict:
    """Build the metadata payload describing this mail message.

    Returns:
        A dict with the message's ``source_id``, ``message_id``, ``subject``,
        ``sender``, ``recipients`` and ``folder``, the ``tags`` of the related
        source item, and ``date`` as the ISO-formatted ``sent_at`` (``None``
        when ``sent_at`` is not set).
    """
    return {
        "source_id": self.source_id,
        "message_id": self.message_id,
        "subject": self.subject,
        "sender": self.sender,
        "recipients": self.recipients,
        "folder": self.folder,
        "tags": self.source.tags,
        # Conditional expression instead of the fragile `a and b or None` idiom.
        "date": self.sent_at.isoformat() if self.sent_at else None,
    }
|
||||||
|
|
||||||
# Add indexes
|
# Add indexes
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('mail_sent_idx', 'sent_at'),
|
Index('mail_sent_idx', 'sent_at'),
|
||||||
@ -80,6 +95,17 @@ class EmailAttachment(Base):
|
|||||||
|
|
||||||
mail_message = relationship("MailMessage", back_populates="attachments")
|
mail_message = relationship("MailMessage", back_populates="attachments")
|
||||||
|
|
||||||
|
def as_payload(self) -> dict:
    """Build the metadata payload describing this attachment.

    Returns:
        A dict with the attachment's ``filename``, ``content_type``, ``size``
        and ``mail_message_id``, ``created_at`` in ISO format (``None`` when
        unset), plus the ``source_id`` of the parent mail message and the
        ``tags`` of that message's source item.
    """
    return {
        "filename": self.filename,
        "content_type": self.content_type,
        "size": self.size,
        # Conditional expression instead of the fragile `a and b or None` idiom.
        "created_at": self.created_at.isoformat() if self.created_at else None,
        "mail_message_id": self.mail_message_id,
        "source_id": self.mail_message.source_id,
        "tags": self.mail_message.source.tags,
    }
|
||||||
|
|
||||||
# Add indexes
|
# Add indexes
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index('email_attachment_message_idx', 'mail_message_id'),
|
Index('email_attachment_message_idx', 'mail_message_id'),
|
||||||
|
12
src/memory/common/embedding.py
Normal file
12
src/memory/common/embedding.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
    """
    Return a placeholder embedding for a piece of text.

    No embedding API is called yet: ``text`` and ``model`` are accepted but
    currently ignored, and the result is always an all-zero vector.

    Args:
        text: Text to embed (unused by the placeholder).
        model: Name of the embedding model (unused by the placeholder).
        n_dimensions: Length of the returned vector.

    Returns:
        A list of ``n_dimensions`` zeros.
    """
    return [0.0] * n_dimensions  # Placeholder n_dimensions-dimensional vector
|
||||||
|
|
||||||
|
|
||||||
|
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
    """
    Return a placeholder embedding for a file.

    The file is not opened and no embedding API is called yet: ``file_path``
    and ``model`` are accepted but currently ignored, and the result is always
    an all-zero vector.

    Args:
        file_path: Path of the file to embed (unused by the placeholder).
        model: Name of the embedding model (unused by the placeholder).
        n_dimensions: Length of the returned vector.

    Returns:
        A list of ``n_dimensions`` zeros.
    """
    return [0.0] * n_dimensions  # Placeholder n_dimensions-dimensional vector
|
@ -4,6 +4,9 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
def boolean_env(key: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean flag.

    Args:
        key: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` when its
        value is ``"1"``, ``"true"`` or ``"yes"`` (case-insensitive) and
        ``False`` for any other value.
    """
    raw = os.getenv(key)
    if raw is None:
        # Previously `default` was ignored and an unset variable was always False.
        return default
    return raw.lower() in ("1", "true", "yes")
|
||||||
|
|
||||||
|
|
||||||
DB_USER = os.getenv("DB_USER", "kb")
|
DB_USER = os.getenv("DB_USER", "kb")
|
||||||
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
|
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
|
||||||
@ -22,3 +25,11 @@ FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
|
|||||||
|
|
||||||
# Maximum attachment size to store directly in the database (defaults to 1MB)
|
# Maximum attachment size to store directly in the database (defaults to 1MB)
|
||||||
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
|
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
|
||||||
|
|
||||||
|
# Qdrant settings
# Each value can be overridden through the environment variable of the same name.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
# No API key is configured unless one is provided in the environment.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "60"))
|
@ -13,7 +13,8 @@ import pathlib
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||||
from memory.common import settings
|
from memory.common import settings, embedding
|
||||||
|
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -456,3 +457,32 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
|||||||
conn.logout()
|
conn.logout()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
|
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def vectorize_email(email: MailMessage) -> list[str]:
    """Embed a mail message and its attachments and store the vectors in Qdrant.

    One point is created for the message body (``email.body_raw``) and one for
    each attachment. Attachments that have a ``file_path`` are embedded with
    ``embedding.embed_file``; the others with ``embedding.embed_text`` on their
    ``content``. All points are upserted into the ``"mail"`` collection in a
    single call.

    Args:
        email: The message to vectorize.

    Returns:
        The generated point IDs as UUID strings: the body's ID first, followed
        by one ID per attachment in iteration order.
    """
    qdrant_client = get_qdrant_client()

    # The three lists are kept index-aligned: entry i of each describes point i.
    vector_ids = [str(uuid.uuid4())]
    vectors = [embedding.embed_text(email.body_raw)]
    payloads = [email.as_payload()]

    for attachment in email.attachments:
        vector_ids.append(str(uuid.uuid4()))
        payloads.append(attachment.as_payload())

        if attachment.file_path:
            vector = embedding.embed_file(attachment.file_path)
        else:
            vector = embedding.embed_text(attachment.content)
        vectors.append(vector)

    upsert_vectors(
        client=qdrant_client,
        collection_name="mail",
        ids=vector_ids,
        vectors=vectors,
        payloads=payloads,
    )

    logger.info(f"Stored embedding for message {email.message_id}")
    return vector_ids
|
||||||
|
231
src/memory/workers/qdrant.py
Normal file
231
src/memory/workers/qdrant.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
|
import qdrant_client
|
||||||
|
from qdrant_client.http import models as qdrant_models
|
||||||
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type of distance metric
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
# A single embedding vector.
Vector = list[float]


class _CollectionRequired(TypedDict):
    """Keys every collection description must provide."""

    # Size of the vectors stored in the collection.
    dimension: int


class Collection(_CollectionRequired, total=False):
    """Parameters describing a Qdrant collection.

    Only ``dimension`` is required. The remaining keys may be omitted, in
    which case the code consuming the description applies its own defaults.
    """

    # Distance metric used to compare vectors.
    distance: DistanceType
    # On-disk storage flag for the collection.
    on_disk: bool
    # Number of shards for the collection.
    shards: int
|
||||||
|
|
||||||
|
|
||||||
|
# Collections used when no explicit mapping is supplied, keyed by collection
# name. Only "mail" is active; the commented-out entries are not created.
DEFAULT_COLLECTIONS: dict[str, Collection] = {
    "mail": {
        "dimension": 1536,
        "distance": "Cosine",
    },
    # "chat": {"dimension": 1536, "distance": "Cosine"},
    # "git": {"dimension": 1536, "distance": "Cosine"},
    # "photo": {"dimension": 512, "distance": "Cosine"},
    # "book": {"dimension": 1536, "distance": "Cosine"},
    # "blog": {"dimension": 1536, "distance": "Cosine"},
    # "doc": {"dimension": 1536, "distance": "Cosine"},
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_qdrant_client() -> qdrant_client.QdrantClient:
    """Build a Qdrant client from the ``QDRANT_*`` values in ``settings``."""
    logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")

    use_grpc = settings.QDRANT_PREFER_GRPC
    # The gRPC port is only handed to the client when gRPC is preferred.
    grpc_port = settings.QDRANT_GRPC_PORT if use_grpc else None

    client = qdrant_client.QdrantClient(
        host=settings.QDRANT_HOST,
        port=settings.QDRANT_PORT,
        grpc_port=grpc_port,
        prefer_grpc=use_grpc,
        api_key=settings.QDRANT_API_KEY,
        timeout=settings.QDRANT_TIMEOUT,
    )
    return client
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_collection_exists(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    dimension: int,
    distance: DistanceType = "Cosine",
    on_disk: bool = True,
    shards: int = 1,
) -> bool:
    """
    Ensure a collection exists with the specified parameters.

    If looking the collection up raises ``UnexpectedResponse`` or
    ``ValueError``, the collection is treated as missing and is created
    together with a keyword payload index on the ``tags`` field. An existing
    collection is left untouched; its parameters are not compared.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        dimension: Vector dimension
        distance: Distance metric (Cosine, Dot, Euclidean)
        on_disk: Passed to Qdrant as ``on_disk_payload``
        shards: Number of shards for the collection

    Returns:
        True if the collection was created, False if it already existed

    Raises:
        ValueError: If ``distance`` is not a metric known to qdrant.
    """
    try:
        client.get_collection(collection_name)
    except (UnexpectedResponse, ValueError):
        exists = False
    else:
        exists = True

    if exists:
        logger.debug(f"Collection {collection_name} already exists")
        return False

    # `cast()` performs no runtime conversion, and qdrant's Distance enum
    # spells the Euclidean metric "Euclid", so translate the name explicitly.
    metric = qdrant_models.Distance("Euclid" if distance == "Euclidean" else distance)

    logger.info(f"Creating collection {collection_name} with dimension {dimension}")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=qdrant_models.VectorParams(
            size=dimension,
            distance=metric,
        ),
        on_disk_payload=on_disk,
        shard_number=shards,
    )

    # Create common payload indexes
    client.create_payload_index(
        collection_name=collection_name,
        field_name="tags",
        field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
    )

    return True
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_collections(
    client: qdrant_client.QdrantClient,
    # String annotation: states the optional type explicitly (implicit
    # Optional is rejected by current mypy) without evaluating `|` at runtime.
    collections: "dict[str, Collection] | None" = None,
) -> None:
    """
    Initialize all required collections in Qdrant.

    Each entry is passed to ``ensure_collection_exists``, so collections that
    already exist are left as they are.

    Args:
        client: Qdrant client
        collections: Dictionary mapping collection names to their parameters.
            If None, defaults to the DEFAULT_COLLECTIONS. ``dimension`` is
            required per entry; ``distance``, ``on_disk`` and ``shards``
            fall back to "Cosine", True and 1.
    """
    if collections is None:
        collections = DEFAULT_COLLECTIONS

    for name, params in collections.items():
        ensure_collection_exists(
            client,
            collection_name=name,
            dimension=params["dimension"],
            distance=params.get("distance", "Cosine"),
            on_disk=params.get("on_disk", True),
            shards=params.get("shards", 1),
        )
|
||||||
|
|
||||||
|
|
||||||
|
def setup_qdrant() -> qdrant_client.QdrantClient:
    """Create a Qdrant client and make sure the default collections exist.

    Returns:
        The client, after ``initialize_collections`` has been run on it
    """
    qdrant = get_qdrant_client()
    initialize_collections(qdrant)
    return qdrant
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
    vectors: list[Vector],
    # String annotation: explicit optional type without runtime `|` evaluation.
    payloads: "list[dict[str, Any]] | None" = None,
) -> None:
    """Upsert vectors into a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: List of vector IDs (as strings)
        vectors: List of vectors
        payloads: List of payloads, one per vector. If None, every vector
            gets an empty payload.

    Raises:
        ValueError: If ``ids``, ``vectors`` and ``payloads`` differ in length.
    """
    if payloads is None:
        payloads = [{} for _ in ids]

    # zip() stops at the shortest input, so mismatched lengths would silently
    # drop points; fail loudly instead.
    if not (len(ids) == len(vectors) == len(payloads)):
        raise ValueError(
            "ids, vectors and payloads must have the same length "
            f"(got {len(ids)}, {len(vectors)} and {len(payloads)})"
        )

    points = [
        qdrant_models.PointStruct(
            id=id_str,
            vector=vector,
            payload=payload,
        )
        for id_str, vector, payload in zip(ids, vectors, payloads)
    ]

    client.upsert(
        collection_name=collection_name,
        points=points,
    )

    logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def search_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    query_vector: Vector,
    # String annotation: explicit optional type without runtime `|` evaluation.
    filter_params: "dict | None" = None,
    limit: int = 10,
) -> list[qdrant_models.ScoredPoint]:
    """Search for similar vectors in a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        query_vector: Query vector
        filter_params: Keyword arguments for ``qdrant_models.Filter``, e.g.
            {"must": [{"key": "tags", "match": {"value": "work"}}]}.
            None or an empty dict means no filtering.
        limit: Maximum number of results to return

    Returns:
        List of scored points
    """
    filter_obj = None
    if filter_params:
        filter_obj = qdrant_models.Filter(**filter_params)

    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=filter_obj,
        limit=limit,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def delete_vectors(
    client: qdrant_client.QdrantClient,
    collection_name: str,
    ids: list[str],
) -> None:
    """
    Remove the points with the given IDs from a collection.

    Args:
        client: Qdrant client
        collection_name: Name of the collection
        ids: IDs of the vectors to delete
    """
    selector = qdrant_models.PointIdsList(points=ids)
    client.delete(collection_name=collection_name, points_selector=selector)

    logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
    """
    Fetch a collection's description and return it as a plain dictionary.

    Args:
        client: Qdrant client
        collection_name: Name of the collection

    Returns:
        The result of ``model_dump()`` on the object returned by
        ``client.get_collection``
    """
    return client.get_collection(collection_name).model_dump()
|
@ -12,6 +12,7 @@ from memory.workers.email import (
|
|||||||
imap_connection,
|
imap_connection,
|
||||||
parse_email_message,
|
parse_email_message,
|
||||||
process_folder,
|
process_folder,
|
||||||
|
vectorize_email,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -63,11 +64,15 @@ def process_message(
|
|||||||
|
|
||||||
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
|
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
|
||||||
|
|
||||||
create_mail_message(db, source_item.id, parsed_email, folder)
|
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
|
||||||
|
|
||||||
|
source_item.vector_ids = vectorize_email(mail_message)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# TODO: Queue for embedding once that's implemented
|
logger.info(f"Stored embedding for message {parsed_email['message_id']}")
|
||||||
|
logger.info("Vector IDs:")
|
||||||
|
for vector_id in source_item.vector_ids:
|
||||||
|
logger.info(f" - {vector_id}")
|
||||||
|
|
||||||
return source_item.id
|
return source_item.id
|
||||||
|
|
||||||
|
@ -6,11 +6,14 @@ import uuid
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import qdrant_client
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from testcontainers.qdrant import QdrantContainer
|
||||||
|
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from tests.providers.email_provider import MockEmailProvider
|
from tests.providers.email_provider import MockEmailProvider
|
||||||
|
from memory.workers.qdrant import initialize_collections
|
||||||
|
|
||||||
|
|
||||||
def get_test_db_name() -> str:
|
def get_test_db_name() -> str:
|
||||||
@ -197,3 +200,13 @@ def email_provider():
|
|||||||
def mock_file_storage(tmp_path: Path):
|
def mock_file_storage(tmp_path: Path):
|
||||||
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
|
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def qdrant():
    """Run a Qdrant test container and yield a client connected to it.

    While the fixture is active, ``qdrant_client.QdrantClient`` is patched to
    return that client, and the default collections are initialized on it.
    """
    with QdrantContainer() as container:
        test_client = container.get_client()
        client_patch = patch.object(qdrant_client, "QdrantClient", return_value=test_client)
        with client_patch:
            initialize_collections(test_client)
            yield test_client
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ def test_email_account(db_session):
|
|||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
def test_process_simple_email(db_session, test_email_account):
|
def test_process_simple_email(db_session, test_email_account, qdrant):
|
||||||
"""Test processing a simple email message."""
|
"""Test processing a simple email message."""
|
||||||
source_id = process_message(
|
source_id = process_message(
|
||||||
account_id=test_email_account.id,
|
account_id=test_email_account.id,
|
||||||
@ -85,7 +85,7 @@ def test_process_simple_email(db_session, test_email_account):
|
|||||||
assert mail_message.folder == "INBOX"
|
assert mail_message.folder == "INBOX"
|
||||||
|
|
||||||
|
|
||||||
def test_process_email_with_attachment(db_session, test_email_account):
|
def test_process_email_with_attachment(db_session, test_email_account, qdrant):
|
||||||
"""Test processing a message with an attachment."""
|
"""Test processing a message with an attachment."""
|
||||||
source_id = process_message(
|
source_id = process_message(
|
||||||
account_id=test_email_account.id,
|
account_id=test_email_account.id,
|
||||||
@ -116,7 +116,7 @@ def test_process_email_with_attachment(db_session, test_email_account):
|
|||||||
assert attachments[0].content is not None or attachments[0].file_path is not None
|
assert attachments[0].content is not None or attachments[0].file_path is not None
|
||||||
|
|
||||||
|
|
||||||
def test_process_empty_message(db_session, test_email_account):
|
def test_process_empty_message(db_session, test_email_account, qdrant):
|
||||||
"""Test processing an empty/invalid message."""
|
"""Test processing an empty/invalid message."""
|
||||||
source_id = process_message(
|
source_id = process_message(
|
||||||
account_id=test_email_account.id,
|
account_id=test_email_account.id,
|
||||||
@ -128,7 +128,7 @@ def test_process_empty_message(db_session, test_email_account):
|
|||||||
assert source_id is None
|
assert source_id is None
|
||||||
|
|
||||||
|
|
||||||
def test_process_duplicate_message(db_session, test_email_account):
|
def test_process_duplicate_message(db_session, test_email_account, qdrant):
|
||||||
"""Test that duplicate messages are detected and not stored again."""
|
"""Test that duplicate messages are detected and not stored again."""
|
||||||
# First call should succeed and create records
|
# First call should succeed and create records
|
||||||
source_id_1 = process_message(
|
source_id_1 = process_message(
|
||||||
|
@ -27,9 +27,21 @@ from memory.workers.email import (
|
|||||||
process_folder,
|
process_folder,
|
||||||
process_attachment,
|
process_attachment,
|
||||||
process_attachments,
|
process_attachments,
|
||||||
|
vectorize_email,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_uuid4():
    """Patch ``uuid.uuid4`` to return predictable, sequential ID strings.

    The n-th call returns a UUID-shaped string whose last group is ``n``
    zero-padded to 12 digits, e.g. ``00000000-0000-0000-0000-000000000001``.
    """
    counter = 0

    def uuid4():
        nonlocal counter
        counter += 1
        # Zero-pad the whole last group so the string keeps its 36-character
        # UUID shape once the counter reaches two digits.
        return f"00000000-0000-0000-0000-{counter:012d}"

    with patch("uuid.uuid4", side_effect=uuid4):
        yield
|
||||||
|
|
||||||
|
|
||||||
# Use a simple counter to generate unique message IDs without calling make_msgid
|
# Use a simple counter to generate unique message IDs without calling make_msgid
|
||||||
_msg_id_counter = 0
|
_msg_id_counter = 0
|
||||||
@ -655,3 +667,102 @@ def test_process_folder_error(email_provider):
|
|||||||
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
|
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
|
||||||
)
|
)
|
||||||
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
|
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
    """A message without attachments yields exactly one vector ID."""
    # Source item the message belongs to
    item = SourceItem(
        id=1,
        modality="mail",
        sha256=b"test_hash" + bytes(24),
        tags=["test"],
        byte_length=100,
        mime_type="message/rfc822",
        embed_status="RAW",
    )
    db_session.add(item)
    db_session.flush()

    # Mail message linked to the source item above
    message = MailMessage(
        id=1,
        source_id=1,
        message_id="<test-vector@example.com>",
        subject="Test Vectorization",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email for vectorization",
        folder="INBOX",
    )
    db_session.add(message)
    db_session.flush()

    fake_text_embedding = patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536)
    with fake_text_embedding:
        vector_ids = vectorize_email(message)

    assert len(vector_ids) == 1
    assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
    """A message with two attachments yields three sequential vector IDs."""
    # Source item the message belongs to
    item = SourceItem(
        id=2,
        modality="mail",
        sha256=b"test_hash2" + bytes(24),
        tags=["test"],
        byte_length=200,
        mime_type="message/rfc822",
        embed_status="RAW",
    )
    db_session.add(item)
    db_session.flush()

    # Mail message linked to the source item above
    message = MailMessage(
        id=2,
        source_id=2,
        message_id="<test-vector-attach@example.com>",
        subject="Test Vectorization with Attachments",
        sender="sender@example.com",
        recipients=["recipient@example.com"],
        body_raw="This is a test email with attachments",
        folder="INBOX",
    )
    db_session.add(message)
    db_session.flush()

    # One attachment carries inline content, the other only a file path
    inline_attachment = EmailAttachment(
        id=1,
        mail_message_id=message.id,
        filename="inline.txt",
        content_type="text/plain",
        size=100,
        content=base64.b64encode(b"This is inline content"),
        file_path=None,
    )
    stored_attachment = EmailAttachment(
        id=2,
        mail_message_id=message.id,
        filename="stored.txt",
        content_type="text/plain",
        size=200,
        content=None,
        file_path="/path/to/stored.txt",
    )
    db_session.add_all([inline_attachment, stored_attachment])
    db_session.flush()

    # The embedding functions are mocked; the qdrant fixture is used as-is
    with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
        with patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
            vector_ids = vectorize_email(message)

    # One ID for the body, then one per attachment
    assert len(vector_ids) == 3
    assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
    assert vector_ids[1] == "00000000-0000-0000-0000-000000000002"
    assert vector_ids[2] == "00000000-0000-0000-0000-000000000003"
|
||||||
|
82
tests/memory/workers/test_qdrant.py
Normal file
82
tests/memory/workers/test_qdrant.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import qdrant_client
|
||||||
|
from qdrant_client.http import models as qdrant_models
|
||||||
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
|
|
||||||
|
from memory.workers.qdrant import (
|
||||||
|
DEFAULT_COLLECTIONS,
|
||||||
|
ensure_collection_exists,
|
||||||
|
initialize_collections,
|
||||||
|
upsert_vectors,
|
||||||
|
search_vectors,
|
||||||
|
delete_vectors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_qdrant_client():
    """Patch ``qdrant_client.QdrantClient`` and yield the mock that replaces it."""
    patcher = patch.object(qdrant_client, "QdrantClient", return_value=MagicMock())
    with patcher as mock_client:
        yield mock_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_collection_exists_existing(mock_qdrant_client):
    """When the lookup succeeds, nothing is created and the result is falsy."""
    mock_qdrant_client.get_collection.return_value = {}

    created = ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

    assert not created
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_collection_exists_new(mock_qdrant_client):
    """When the lookup fails, the collection and its payload index are created."""
    lookup_error = UnexpectedResponse(
        status_code=404, reason_phrase='asd', content=b'asd', headers=None
    )
    mock_qdrant_client.get_collection.side_effect = lookup_error

    created = ensure_collection_exists(mock_qdrant_client, "test_collection", 128)

    assert created
    mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
    mock_qdrant_client.create_collection.assert_called_once()
    mock_qdrant_client.create_payload_index.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_collections(mock_qdrant_client):
    """Every default collection is looked up exactly once."""
    initialize_collections(mock_qdrant_client)

    expected_lookups = len(DEFAULT_COLLECTIONS)
    assert mock_qdrant_client.get_collection.call_count == expected_lookups
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_vectors(mock_qdrant_client):
    """IDs, vectors and payloads are paired up into the upserted points."""
    ids = ["1", "2"]
    vectors = [[0.1, 0.2], [0.3, 0.4]]
    payloads = [{"tag": "test1"}, {"tag": "test2"}]

    upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads)

    mock_qdrant_client.upsert.assert_called_once()
    _, kwargs = mock_qdrant_client.upsert.call_args
    assert kwargs["collection_name"] == "test_collection"

    # Check points were created correctly
    points = kwargs["points"]
    assert len(points) == 2

    first, second = points
    assert first.id == "1"
    assert first.vector == [0.1, 0.2]
    assert first.payload == {"tag": "test1"}
    assert second.id == "2"
    assert second.vector == [0.3, 0.4]
    assert second.payload == {"tag": "test2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_vectors(mock_qdrant_client):
    """The IDs are wrapped in a PointIdsList selector for the delete call."""
    ids = ["1", "2"]

    delete_vectors(mock_qdrant_client, "test_collection", ids)

    mock_qdrant_client.delete.assert_called_once()
    _, kwargs = mock_qdrant_client.delete.call_args
    selector = kwargs["points_selector"]

    assert kwargs["collection_name"] == "test_collection"
    assert isinstance(selector, qdrant_models.PointIdsList)
    assert selector.points == ids
|
Loading…
x
Reference in New Issue
Block a user