mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
add qdrant
This commit is contained in:
parent
889df318a1
commit
a104a3211b
@ -2,4 +2,5 @@ pytest==7.4.4
|
||||
pytest-cov==4.1.0
|
||||
black==23.12.1
|
||||
mypy==1.8.0
|
||||
isort==5.13.2
|
||||
isort==5.13.2
|
||||
testcontainers[qdrant]==4.10.0
|
@ -29,6 +29,8 @@ class SourceItem(Base):
|
||||
byte_length = Column(Integer)
|
||||
mime_type = Column(Text)
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="source", uselist=False)
|
||||
|
||||
# Add table-level constraint and indexes
|
||||
__table_args__ = (
|
||||
CheckConstraint("embed_status IN ('RAW','QUEUED','STORED','FAILED')"),
|
||||
@ -53,11 +55,24 @@ class MailMessage(Base):
|
||||
tsv = Column(TSVECTOR)
|
||||
|
||||
attachments = relationship("EmailAttachment", back_populates="mail_message", cascade="all, delete-orphan")
|
||||
source = relationship("SourceItem", back_populates="mail_message")
|
||||
|
||||
@property
|
||||
def attachments_path(self) -> Path:
|
||||
return Path(settings.FILE_STORAGE_DIR) / self.sender / (self.folder or 'INBOX')
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"source_id": self.source_id,
|
||||
"message_id": self.message_id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
"recipients": self.recipients,
|
||||
"folder": self.folder,
|
||||
"tags": self.source.tags,
|
||||
"date": self.sent_at and self.sent_at.isoformat() or None,
|
||||
}
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('mail_sent_idx', 'sent_at'),
|
||||
@ -80,6 +95,17 @@ class EmailAttachment(Base):
|
||||
|
||||
mail_message = relationship("MailMessage", back_populates="attachments")
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
"filename": self.filename,
|
||||
"content_type": self.content_type,
|
||||
"size": self.size,
|
||||
"created_at": self.created_at and self.created_at.isoformat() or None,
|
||||
"mail_message_id": self.mail_message_id,
|
||||
"source_id": self.mail_message.source_id,
|
||||
"tags": self.mail_message.source.tags,
|
||||
}
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (
|
||||
Index('email_attachment_message_idx', 'mail_message_id'),
|
||||
|
12
src/memory/common/embedding.py
Normal file
12
src/memory/common/embedding.py
Normal file
@ -0,0 +1,12 @@
|
||||
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
||||
"""
|
||||
Embed a text using OpenAI's API.
|
||||
"""
|
||||
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
||||
|
||||
|
||||
def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
||||
"""
|
||||
Embed a file using OpenAI's API.
|
||||
"""
|
||||
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
@ -4,6 +4,9 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def boolean_env(key: str, default: bool = False) -> bool:
|
||||
return os.getenv(key, "0").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
DB_USER = os.getenv("DB_USER", "kb")
|
||||
DB_PASSWORD = os.getenv("DB_PASSWORD", "kb")
|
||||
@ -21,4 +24,12 @@ FILE_STORAGE_DIR = pathlib.Path(os.getenv("FILE_STORAGE_DIR", "/tmp/memory_files
|
||||
FILE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Maximum attachment size to store directly in the database (10MB)
|
||||
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
|
||||
MAX_INLINE_ATTACHMENT_SIZE = int(os.getenv("MAX_INLINE_ATTACHMENT_SIZE", 1 * 1024 * 1024))
|
||||
|
||||
# Qdrant settings
|
||||
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
|
||||
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
||||
QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334"))
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
|
||||
QDRANT_PREFER_GRPC = boolean_env("QDRANT_PREFER_GRPC", False)
|
||||
QDRANT_TIMEOUT = int(os.getenv("QDRANT_TIMEOUT", "60"))
|
@ -13,7 +13,8 @@ import pathlib
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||
from memory.common import settings
|
||||
from memory.common import settings, embedding
|
||||
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -456,3 +457,32 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
||||
conn.logout()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging out from {account.imap_server}: {str(e)}")
|
||||
|
||||
|
||||
def vectorize_email(email: MailMessage) -> list[float]:
|
||||
qdrant_client = get_qdrant_client()
|
||||
|
||||
vector_ids = [str(uuid.uuid4())]
|
||||
vectors = [embedding.embed_text(email.body_raw)]
|
||||
payloads = [email.as_payload()]
|
||||
|
||||
for attachment in email.attachments:
|
||||
vector_ids.append(str(uuid.uuid4()))
|
||||
payloads.append(attachment.as_payload())
|
||||
|
||||
if attachment.file_path:
|
||||
vector = embedding.embed_file(attachment.file_path)
|
||||
else:
|
||||
vector = embedding.embed_text(attachment.content)
|
||||
vectors.append(vector)
|
||||
|
||||
upsert_vectors(
|
||||
client=qdrant_client,
|
||||
collection_name="mail",
|
||||
ids=vector_ids,
|
||||
vectors=vectors,
|
||||
payloads=payloads,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for message {email.message_id}")
|
||||
return vector_ids
|
||||
|
231
src/memory/workers/qdrant.py
Normal file
231
src/memory/workers/qdrant.py
Normal file
@ -0,0 +1,231 @@
|
||||
import logging
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import qdrant_client
|
||||
from qdrant_client.http import models as qdrant_models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type of distance metric
|
||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
||||
Vector = list[float]
|
||||
|
||||
class Collection(TypedDict):
|
||||
dimension: int
|
||||
distance: DistanceType
|
||||
on_disk: bool
|
||||
shards: int
|
||||
|
||||
|
||||
DEFAULT_COLLECTIONS: dict[str, Collection] = {
|
||||
"mail": {"dimension": 1536, "distance": "Cosine"},
|
||||
# "chat": {"dimension": 1536, "distance": "Cosine"},
|
||||
# "git": {"dimension": 1536, "distance": "Cosine"},
|
||||
# "photo": {"dimension": 512, "distance": "Cosine"},
|
||||
# "book": {"dimension": 1536, "distance": "Cosine"},
|
||||
# "blog": {"dimension": 1536, "distance": "Cosine"},
|
||||
# "doc": {"dimension": 1536, "distance": "Cosine"},
|
||||
}
|
||||
|
||||
|
||||
def get_qdrant_client() -> qdrant_client.QdrantClient:
|
||||
"""Create and return a Qdrant client using environment configuration."""
|
||||
logger.info(f"Connecting to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}")
|
||||
|
||||
return qdrant_client.QdrantClient(
|
||||
host=settings.QDRANT_HOST,
|
||||
port=settings.QDRANT_PORT,
|
||||
grpc_port=settings.QDRANT_GRPC_PORT if settings.QDRANT_PREFER_GRPC else None,
|
||||
prefer_grpc=settings.QDRANT_PREFER_GRPC,
|
||||
api_key=settings.QDRANT_API_KEY,
|
||||
timeout=settings.QDRANT_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
def ensure_collection_exists(
|
||||
client: qdrant_client.QdrantClient,
|
||||
collection_name: str,
|
||||
dimension: int,
|
||||
distance: DistanceType = "Cosine",
|
||||
on_disk: bool = True,
|
||||
shards: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure a collection exists with the specified parameters.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collection_name: Name of the collection
|
||||
dimension: Vector dimension
|
||||
distance: Distance metric (Cosine, Dot, Euclidean)
|
||||
on_disk: Whether to store vectors on disk
|
||||
shards: Number of shards for the collection
|
||||
|
||||
Returns:
|
||||
True if the collection was created, False if it already existed
|
||||
"""
|
||||
try:
|
||||
client.get_collection(collection_name)
|
||||
logger.debug(f"Collection {collection_name} already exists")
|
||||
return False
|
||||
except (UnexpectedResponse, ValueError):
|
||||
logger.info(f"Creating collection {collection_name} with dimension {dimension}")
|
||||
client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=qdrant_models.VectorParams(
|
||||
size=dimension,
|
||||
distance=cast(qdrant_models.Distance, distance),
|
||||
),
|
||||
on_disk_payload=on_disk,
|
||||
shard_number=shards,
|
||||
)
|
||||
|
||||
# Create common payload indexes
|
||||
client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name="tags",
|
||||
field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def initialize_collections(client: qdrant_client.QdrantClient, collections: dict[str, Collection] = None) -> None:
|
||||
"""
|
||||
Initialize all required collections in Qdrant.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collections: Dictionary mapping collection names to their parameters.
|
||||
If None, defaults to the DEFAULT_COLLECTIONS.
|
||||
"""
|
||||
if collections is None:
|
||||
collections = DEFAULT_COLLECTIONS
|
||||
|
||||
for name, params in collections.items():
|
||||
ensure_collection_exists(
|
||||
client,
|
||||
collection_name=name,
|
||||
dimension=params["dimension"],
|
||||
distance=params.get("distance", "Cosine"),
|
||||
on_disk=params.get("on_disk", True),
|
||||
shards=params.get("shards", 1),
|
||||
)
|
||||
|
||||
|
||||
def setup_qdrant() -> qdrant_client.QdrantClient:
|
||||
"""Get a Qdrant client and initialize collections.
|
||||
|
||||
Returns:
|
||||
Configured Qdrant client
|
||||
"""
|
||||
client = get_qdrant_client()
|
||||
initialize_collections(client)
|
||||
return client
|
||||
|
||||
|
||||
def upsert_vectors(
|
||||
client: qdrant_client.QdrantClient,
|
||||
collection_name: str,
|
||||
ids: list[str],
|
||||
vectors: list[Vector],
|
||||
payloads: list[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Upsert vectors into a collection.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collection_name: Name of the collection
|
||||
ids: List of vector IDs (as strings)
|
||||
vectors: List of vectors
|
||||
payloads: List of payloads, one per vector
|
||||
"""
|
||||
if payloads is None:
|
||||
payloads = [{} for _ in ids]
|
||||
|
||||
points = [
|
||||
qdrant_models.PointStruct(
|
||||
id=id_str,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
)
|
||||
for id_str, vector, payload in zip(ids, vectors, payloads)
|
||||
]
|
||||
|
||||
client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=points,
|
||||
)
|
||||
|
||||
logger.debug(f"Upserted {len(ids)} vectors into {collection_name}")
|
||||
|
||||
|
||||
def search_vectors(
|
||||
client: qdrant_client.QdrantClient,
|
||||
collection_name: str,
|
||||
query_vector: Vector,
|
||||
filter_params: dict = None,
|
||||
limit: int = 10,
|
||||
) -> list[qdrant_models.ScoredPoint]:
|
||||
"""Search for similar vectors in a collection.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collection_name: Name of the collection
|
||||
query_vector: Query vector
|
||||
filter_params: Filter parameters to apply (e.g., {"tags": {"value": "work"}})
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of scored points
|
||||
"""
|
||||
filter_obj = None
|
||||
if filter_params:
|
||||
filter_obj = qdrant_models.Filter(**filter_params)
|
||||
|
||||
return client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter_obj,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
def delete_vectors(
|
||||
client: qdrant_client.QdrantClient,
|
||||
collection_name: str,
|
||||
ids: list[str],
|
||||
) -> None:
|
||||
"""
|
||||
Delete vectors from a collection.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collection_name: Name of the collection
|
||||
ids: List of vector IDs to delete
|
||||
"""
|
||||
client.delete(
|
||||
collection_name=collection_name,
|
||||
points_selector=qdrant_models.PointIdsList(
|
||||
points=ids,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(f"Deleted {len(ids)} vectors from {collection_name}")
|
||||
|
||||
|
||||
def get_collection_info(client: qdrant_client.QdrantClient, collection_name: str) -> dict:
|
||||
"""
|
||||
Get information about a collection.
|
||||
|
||||
Args:
|
||||
client: Qdrant client
|
||||
collection_name: Name of the collection
|
||||
|
||||
Returns:
|
||||
Dictionary with collection information
|
||||
"""
|
||||
info = client.get_collection(collection_name)
|
||||
return info.model_dump()
|
@ -12,6 +12,7 @@ from memory.workers.email import (
|
||||
imap_connection,
|
||||
parse_email_message,
|
||||
process_folder,
|
||||
vectorize_email,
|
||||
)
|
||||
|
||||
|
||||
@ -63,11 +64,15 @@ def process_message(
|
||||
|
||||
source_item = create_source_item(db, message_hash, account.tags, len(raw_email))
|
||||
|
||||
create_mail_message(db, source_item.id, parsed_email, folder)
|
||||
mail_message = create_mail_message(db, source_item.id, parsed_email, folder)
|
||||
|
||||
source_item.vector_ids = vectorize_email(mail_message)
|
||||
db.commit()
|
||||
|
||||
# TODO: Queue for embedding once that's implemented
|
||||
|
||||
logger.info(f"Stored embedding for message {parsed_email['message_id']}")
|
||||
logger.info("Vector IDs:")
|
||||
for vector_id in source_item.vector_ids:
|
||||
logger.info(f" - {vector_id}")
|
||||
|
||||
return source_item.id
|
||||
|
||||
|
@ -6,11 +6,14 @@ import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import qdrant_client
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from testcontainers.qdrant import QdrantContainer
|
||||
|
||||
from memory.common import settings
|
||||
from tests.providers.email_provider import MockEmailProvider
|
||||
from memory.workers.qdrant import initialize_collections
|
||||
|
||||
|
||||
def get_test_db_name() -> str:
|
||||
@ -197,3 +200,13 @@ def email_provider():
|
||||
def mock_file_storage(tmp_path: Path):
|
||||
with patch("memory.common.settings.FILE_STORAGE_DIR", tmp_path):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant():
|
||||
with QdrantContainer() as qdrant:
|
||||
client = qdrant.get_client()
|
||||
with patch.object(qdrant_client, "QdrantClient", return_value=client):
|
||||
initialize_collections(client)
|
||||
yield client
|
||||
|
||||
|
@ -56,7 +56,7 @@ def test_email_account(db_session):
|
||||
return account
|
||||
|
||||
|
||||
def test_process_simple_email(db_session, test_email_account):
|
||||
def test_process_simple_email(db_session, test_email_account, qdrant):
|
||||
"""Test processing a simple email message."""
|
||||
source_id = process_message(
|
||||
account_id=test_email_account.id,
|
||||
@ -85,7 +85,7 @@ def test_process_simple_email(db_session, test_email_account):
|
||||
assert mail_message.folder == "INBOX"
|
||||
|
||||
|
||||
def test_process_email_with_attachment(db_session, test_email_account):
|
||||
def test_process_email_with_attachment(db_session, test_email_account, qdrant):
|
||||
"""Test processing a message with an attachment."""
|
||||
source_id = process_message(
|
||||
account_id=test_email_account.id,
|
||||
@ -116,7 +116,7 @@ def test_process_email_with_attachment(db_session, test_email_account):
|
||||
assert attachments[0].content is not None or attachments[0].file_path is not None
|
||||
|
||||
|
||||
def test_process_empty_message(db_session, test_email_account):
|
||||
def test_process_empty_message(db_session, test_email_account, qdrant):
|
||||
"""Test processing an empty/invalid message."""
|
||||
source_id = process_message(
|
||||
account_id=test_email_account.id,
|
||||
@ -128,7 +128,7 @@ def test_process_empty_message(db_session, test_email_account):
|
||||
assert source_id is None
|
||||
|
||||
|
||||
def test_process_duplicate_message(db_session, test_email_account):
|
||||
def test_process_duplicate_message(db_session, test_email_account, qdrant):
|
||||
"""Test that duplicate messages are detected and not stored again."""
|
||||
# First call should succeed and create records
|
||||
source_id_1 = process_message(
|
||||
|
@ -27,9 +27,21 @@ from memory.workers.email import (
|
||||
process_folder,
|
||||
process_attachment,
|
||||
process_attachments,
|
||||
vectorize_email,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_uuid4():
|
||||
i = 0
|
||||
def uuid4():
|
||||
nonlocal i
|
||||
i += 1
|
||||
return f"00000000-0000-0000-0000-00000000000{i}"
|
||||
|
||||
with patch("uuid.uuid4", side_effect=uuid4):
|
||||
yield
|
||||
|
||||
|
||||
# Use a simple counter to generate unique message IDs without calling make_msgid
|
||||
_msg_id_counter = 0
|
||||
@ -655,3 +667,102 @@ def test_process_folder_error(email_provider):
|
||||
email_provider, "INBOX", account, datetime(1970, 1, 1), mock_processor
|
||||
)
|
||||
assert result == {"messages_found": 0, "new_messages": 0, "errors": 0}
|
||||
|
||||
|
||||
def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
||||
source_item = SourceItem(
|
||||
id=1,
|
||||
modality="mail",
|
||||
sha256=b"test_hash" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=100,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
|
||||
# Create mail message
|
||||
mail_message = MailMessage(
|
||||
id=1,
|
||||
source_id=1,
|
||||
message_id="<test-vector@example.com>",
|
||||
subject="Test Vectorization",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
body_raw="This is a test email for vectorization",
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
|
||||
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
|
||||
vector_ids = vectorize_email(mail_message)
|
||||
|
||||
assert len(vector_ids) == 1
|
||||
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
|
||||
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
||||
source_item = SourceItem(
|
||||
id=2,
|
||||
modality="mail",
|
||||
sha256=b"test_hash2" + bytes(24),
|
||||
tags=["test"],
|
||||
byte_length=200,
|
||||
mime_type="message/rfc822",
|
||||
embed_status="RAW",
|
||||
)
|
||||
db_session.add(source_item)
|
||||
db_session.flush()
|
||||
|
||||
# Create mail message
|
||||
mail_message = MailMessage(
|
||||
id=2,
|
||||
source_id=2,
|
||||
message_id="<test-vector-attach@example.com>",
|
||||
subject="Test Vectorization with Attachments",
|
||||
sender="sender@example.com",
|
||||
recipients=["recipient@example.com"],
|
||||
body_raw="This is a test email with attachments",
|
||||
folder="INBOX",
|
||||
)
|
||||
db_session.add(mail_message)
|
||||
db_session.flush()
|
||||
|
||||
# Add two attachments - one with content and one with file_path
|
||||
attachment1 = EmailAttachment(
|
||||
id=1,
|
||||
mail_message_id=mail_message.id,
|
||||
filename="inline.txt",
|
||||
content_type="text/plain",
|
||||
size=100,
|
||||
content=base64.b64encode(b"This is inline content"),
|
||||
file_path=None,
|
||||
)
|
||||
|
||||
attachment2 = EmailAttachment(
|
||||
id=2,
|
||||
mail_message_id=mail_message.id,
|
||||
filename="stored.txt",
|
||||
content_type="text/plain",
|
||||
size=200,
|
||||
content=None,
|
||||
file_path="/path/to/stored.txt",
|
||||
)
|
||||
|
||||
db_session.add_all([attachment1, attachment2])
|
||||
db_session.flush()
|
||||
|
||||
# Mock embedding functions but use real qdrant
|
||||
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \
|
||||
patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
|
||||
|
||||
# Call the function
|
||||
vector_ids = vectorize_email(mail_message)
|
||||
|
||||
# Verify results
|
||||
assert len(vector_ids) == 3
|
||||
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
|
||||
assert vector_ids[1] == "00000000-0000-0000-0000-000000000002"
|
||||
assert vector_ids[2] == "00000000-0000-0000-0000-000000000003"
|
||||
|
82
tests/memory/workers/test_qdrant.py
Normal file
82
tests/memory/workers/test_qdrant.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import qdrant_client
|
||||
from qdrant_client.http import models as qdrant_models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
from memory.workers.qdrant import (
|
||||
DEFAULT_COLLECTIONS,
|
||||
ensure_collection_exists,
|
||||
initialize_collections,
|
||||
upsert_vectors,
|
||||
search_vectors,
|
||||
delete_vectors,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
with patch.object(qdrant_client, "QdrantClient", return_value=MagicMock()) as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_ensure_collection_exists_existing(mock_qdrant_client):
|
||||
mock_qdrant_client.get_collection.return_value = {}
|
||||
assert not ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
|
||||
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_collection_exists_new(mock_qdrant_client):
|
||||
mock_qdrant_client.get_collection.side_effect = UnexpectedResponse(
|
||||
status_code=404, reason_phrase='asd', content=b'asd', headers=None
|
||||
)
|
||||
|
||||
assert ensure_collection_exists(mock_qdrant_client, "test_collection", 128)
|
||||
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_called_once()
|
||||
mock_qdrant_client.create_payload_index.assert_called_once()
|
||||
|
||||
|
||||
def test_initialize_collections(mock_qdrant_client):
|
||||
initialize_collections(mock_qdrant_client)
|
||||
|
||||
assert mock_qdrant_client.get_collection.call_count == len(DEFAULT_COLLECTIONS)
|
||||
|
||||
|
||||
def test_upsert_vectors(mock_qdrant_client):
|
||||
ids = ["1", "2"]
|
||||
vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
payloads = [{"tag": "test1"}, {"tag": "test2"}]
|
||||
|
||||
upsert_vectors(mock_qdrant_client, "test_collection", ids, vectors, payloads)
|
||||
|
||||
mock_qdrant_client.upsert.assert_called_once()
|
||||
args, kwargs = mock_qdrant_client.upsert.call_args
|
||||
assert kwargs["collection_name"] == "test_collection"
|
||||
assert len(kwargs["points"]) == 2
|
||||
|
||||
# Check points were created correctly
|
||||
points = kwargs["points"]
|
||||
assert points[0].id == "1"
|
||||
assert points[0].vector == [0.1, 0.2]
|
||||
assert points[0].payload == {"tag": "test1"}
|
||||
assert points[1].id == "2"
|
||||
assert points[1].vector == [0.3, 0.4]
|
||||
assert points[1].payload == {"tag": "test2"}
|
||||
|
||||
|
||||
def test_delete_vectors(mock_qdrant_client):
|
||||
ids = ["1", "2"]
|
||||
|
||||
delete_vectors(mock_qdrant_client, "test_collection", ids)
|
||||
|
||||
mock_qdrant_client.delete.assert_called_once()
|
||||
args, kwargs = mock_qdrant_client.delete.call_args
|
||||
|
||||
assert kwargs["collection_name"] == "test_collection"
|
||||
assert isinstance(kwargs["points_selector"], qdrant_models.PointIdsList)
|
||||
assert kwargs["points_selector"].points == ids
|
Loading…
x
Reference in New Issue
Block a user