better vectorization

This commit is contained in:
Daniel O'Connell 2025-04-27 23:01:13 +02:00
parent a104a3211b
commit 3dca666d08
4 changed files with 86 additions and 49 deletions

View File

@ -1,3 +1,45 @@
import pathlib
from typing import Literal, TypedDict
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
on_disk: bool
shards: int
DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1536, "distance": "Cosine"},
"chat": {"dimension": 1536, "distance": "Cosine"},
"git": {"dimension": 1536, "distance": "Cosine"},
"photo": {"dimension": 512, "distance": "Cosine"},
"book": {"dimension": 1536, "distance": "Cosine"},
"blog": {"dimension": 1536, "distance": "Cosine"},
"doc": {"dimension": 1536, "distance": "Cosine"},
}
TYPES = {
"doc": ["text/*"],
"photo": ["image/*"],
"book": ["application/pdf", "application/epub+zip", "application/mobi", "application/x-mobipocket-ebook"],
}
def get_type(mime_type: str) -> str:
for type, mime_types in TYPES.items():
if mime_type in mime_types:
return type
stem = mime_type.split("/")[0]
for type, mime_types in TYPES.items():
if any(mime_type.startswith(stem) for mime_type in mime_types):
return type
return "unknown"
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
"""
Embed a text using OpenAI's API.
@ -10,3 +52,9 @@ def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensio
Embed a file using OpenAI's API.
"""
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
def embed(mime_type: str, content: bytes | str | pathlib.Path, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> tuple[str, list[float]]:
collection = get_type(mime_type)
return collection, [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector

View File

@ -11,7 +11,7 @@ from email.utils import parsedate_to_datetime
from typing import Generator, Callable, TypedDict, Literal
import pathlib
from sqlalchemy.orm import Session
from collections import defaultdict
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import settings, embedding
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
@ -462,27 +462,35 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
def vectorize_email(email: MailMessage) -> list[float]:
qdrant_client = get_qdrant_client()
vector_ids = [str(uuid.uuid4())]
vectors = [embedding.embed_text(email.body_raw)]
payloads = [email.as_payload()]
for attachment in email.attachments:
vector_ids.append(str(uuid.uuid4()))
payloads.append(attachment.as_payload())
if attachment.file_path:
vector = embedding.embed_file(attachment.file_path)
else:
vector = embedding.embed_text(attachment.content)
vectors.append(vector)
vector_id = uuid.uuid4()
upsert_vectors(
client=qdrant_client,
collection_name="mail",
ids=vector_ids,
vectors=vectors,
payloads=payloads,
ids=[str(vector_id)],
vectors=[embedding.embed_text(email.body_raw)],
payloads=[email.as_payload()],
)
vector_ids = [f"mail/{vector_id}"]
embeds = defaultdict(list)
for attachment in email.attachments:
if attachment.file_path:
content = pathlib.Path(attachment.file_path).read_bytes()
else:
content = attachment.content
collection, vector = embedding.embed(attachment.content_type, content)
embeds[collection].append((str(uuid.uuid4()), vector, attachment.as_payload()))
for collection, embeds in embeds.items():
ids, vectors, payloads = zip(*embeds)
upsert_vectors(
client=qdrant_client,
collection_name=collection,
ids=ids,
vectors=vectors,
payloads=payloads,
)
vector_ids.extend([f"{collection}/{vector_id}" for vector_id in ids])
logger.info(f"Stored embedding for message {email.message_id}")
return vector_ids

View File

@ -1,34 +1,14 @@
import logging
from typing import Any, Literal, TypedDict, cast
from typing import Any, cast
import qdrant_client
from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings
from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
logger = logging.getLogger(__name__)
# Type of distance metric
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
on_disk: bool
shards: int
DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1536, "distance": "Cosine"},
# "chat": {"dimension": 1536, "distance": "Cosine"},
# "git": {"dimension": 1536, "distance": "Cosine"},
# "photo": {"dimension": 512, "distance": "Cosine"},
# "book": {"dimension": 1536, "distance": "Cosine"},
# "blog": {"dimension": 1536, "distance": "Cosine"},
# "doc": {"dimension": 1536, "distance": "Cosine"},
}
def get_qdrant_client() -> qdrant_client.QdrantClient:
"""Create and return a Qdrant client using environment configuration."""

View File

@ -700,7 +700,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@ -741,6 +741,9 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
file_path=None,
)
file_path = mail_message.attachments_path / "stored.txt"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(b"This is stored content")
attachment2 = EmailAttachment(
id=2,
mail_message_id=mail_message.id,
@ -748,21 +751,19 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
content_type="text/plain",
size=200,
content=None,
file_path="/path/to/stored.txt",
file_path=str(file_path),
)
db_session.add_all([attachment1, attachment2])
db_session.flush()
# Mock embedding functions but use real qdrant
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \
patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
# Call the function
vector_ids = vectorize_email(mail_message)
# Verify results
assert len(vector_ids) == 3
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
assert vector_ids[1] == "00000000-0000-0000-0000-000000000002"
assert vector_ids[2] == "00000000-0000-0000-0000-000000000003"
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
assert vector_ids[1] == "doc/00000000-0000-0000-0000-000000000002"
assert vector_ids[2] == "doc/00000000-0000-0000-0000-000000000003"