better vectorization

This commit is contained in:
Daniel O'Connell 2025-04-27 23:01:13 +02:00
parent a104a3211b
commit 3dca666d08
4 changed files with 86 additions and 49 deletions

View File

@ -1,3 +1,45 @@
import pathlib
from typing import Literal, TypedDict
# Type of distance metric a collection can be configured with.
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
# An embedding, represented as a plain list of floats.
Vector = list[float]


class _RequiredCollectionFields(TypedDict):
    """Keys that every collection configuration must provide."""

    dimension: int
    distance: DistanceType


class Collection(_RequiredCollectionFields, total=False):
    """
    Configuration of a single vector collection.

    `dimension` and `distance` are required. `on_disk` and `shards` are
    optional, because the entries in DEFAULT_COLLECTIONS only supply the
    two required keys.
    """

    on_disk: bool
    shards: int
# Collection configurations keyed by collection name. Every entry uses the
# "Cosine" distance; "photo" is the only one with a 512 dimension.
DEFAULT_COLLECTIONS: dict[str, Collection] = {
    "mail": Collection(dimension=1536, distance="Cosine"),
    "chat": Collection(dimension=1536, distance="Cosine"),
    "git": Collection(dimension=1536, distance="Cosine"),
    "photo": Collection(dimension=512, distance="Cosine"),
    "book": Collection(dimension=1536, distance="Cosine"),
    "blog": Collection(dimension=1536, distance="Cosine"),
    "doc": Collection(dimension=1536, distance="Cosine"),
}
# Maps a collection name to the MIME types that belong to it. A pattern that
# ends in "/*" stands for every subtype of that top-level type.
TYPES = {
    "doc": ["text/*"],
    "photo": ["image/*"],
    "book": [
        "application/pdf",
        "application/epub+zip",
        "application/mobi",
        "application/x-mobipocket-ebook",
    ],
}


def _mime_matches(mime_type: str, pattern: str) -> bool:
    """Return True when `mime_type` satisfies a single TYPES pattern."""
    if pattern.endswith("/*"):
        # Keep the slash so "text/*" matches "text/plain" but not "textual/x".
        return mime_type.startswith(pattern[:-1])
    return mime_type == pattern


def get_type(mime_type: str) -> str:
    """
    Return the name of the TYPES entry that `mime_type` belongs to.

    Exact matches are preferred over wildcard ("<type>/*") matches. Matching
    is case-insensitive and ignores parameters such as "; charset=utf-8".

    Only wildcard patterns match on the top-level type, so an unlisted
    "application/..." type is not mistaken for a book.

    Returns "unknown" when nothing matches or `mime_type` is empty.
    """
    if not mime_type:
        return "unknown"

    normalized = mime_type.split(";", 1)[0].strip().lower()

    # First pass: exact matches win regardless of the order of TYPES.
    for name, patterns in TYPES.items():
        if normalized in patterns:
            return name

    # Second pass: fall back to wildcard patterns.
    for name, patterns in TYPES.items():
        if any(_mime_matches(normalized, pattern) for pattern in patterns):
            return name

    return "unknown"
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]: def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
""" """
Embed a text using OpenAI's API. Embed a text using OpenAI's API.
@ -10,3 +52,9 @@ def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensio
Embed a file using OpenAI's API. Embed a file using OpenAI's API.
""" """
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
def embed(
    mime_type: str,
    content: bytes | str | pathlib.Path,
    model: str = "text-embedding-3-small",
    n_dimensions: int = 1536,
) -> tuple[str, list[float]]:
    """
    Choose the collection for `mime_type` and return it with a vector.

    The collection name comes from `get_type(mime_type)`. The vector is
    currently a placeholder made of `n_dimensions` zeros; `content` and
    `model` are accepted but not used yet.

    Returns a `(collection, vector)` tuple.
    """
    # NOTE(review): DEFAULT_COLLECTIONS declares "photo" with dimension 512,
    # while the default here is 1536 - confirm callers pass the right size.
    target = get_type(mime_type)
    placeholder = [0.0] * n_dimensions  # Placeholder n_dimensions-dimensional vector
    return target, placeholder

View File

@ -11,7 +11,7 @@ from email.utils import parsedate_to_datetime
from typing import Generator, Callable, TypedDict, Literal from typing import Generator, Callable, TypedDict, Literal
import pathlib import pathlib
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from collections import defaultdict
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
from memory.common import settings, embedding from memory.common import settings, embedding
from memory.workers.qdrant import get_qdrant_client, upsert_vectors from memory.workers.qdrant import get_qdrant_client, upsert_vectors
@ -462,27 +462,35 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
def vectorize_email(email: MailMessage) -> list[str]:
    """
    Embed an email and its attachments and store the vectors in Qdrant.

    The message body is embedded with `embedding.embed_text` and upserted
    into the "mail" collection. Each attachment is passed to
    `embedding.embed`, which returns the target collection together with
    the vector; attachments are then upserted one batch per collection.

    Returns the stored ids as "<collection>/<uuid>" strings, with the
    message body first and the attachments after it, grouped by collection.
    """
    qdrant_client = get_qdrant_client()

    mail_id = str(uuid.uuid4())
    upsert_vectors(
        client=qdrant_client,
        collection_name="mail",
        ids=[mail_id],
        vectors=[embedding.embed_text(email.body_raw)],
        payloads=[email.as_payload()],
    )
    vector_ids = [f"mail/{mail_id}"]

    # Group (id, vector, payload) triples by the collection they belong to,
    # so each collection needs only a single upsert call.
    by_collection = defaultdict(list)
    for attachment in email.attachments:
        if attachment.file_path:
            content = pathlib.Path(attachment.file_path).read_bytes()
        else:
            content = attachment.content
        collection, vector = embedding.embed(attachment.content_type, content)
        by_collection[collection].append(
            (str(uuid.uuid4()), vector, attachment.as_payload())
        )

    for collection, items in by_collection.items():
        # Build real lists (not zip tuples), matching how the "mail" upsert
        # above passes its arguments.
        ids = [item_id for item_id, _, _ in items]
        vectors = [vector for _, vector, _ in items]
        payloads = [payload for _, _, payload in items]
        upsert_vectors(
            client=qdrant_client,
            collection_name=collection,
            ids=ids,
            vectors=vectors,
            payloads=payloads,
        )
        vector_ids.extend(f"{collection}/{item_id}" for item_id in ids)

    logger.info(f"Stored embedding for message {email.message_id}")
    return vector_ids

View File

@ -1,34 +1,14 @@
import logging import logging
from typing import Any, Literal, TypedDict, cast from typing import Any, cast
import qdrant_client import qdrant_client
from qdrant_client.http import models as qdrant_models from qdrant_client.http import models as qdrant_models
from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.exceptions import UnexpectedResponse
from memory.common import settings from memory.common import settings
from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Type of distance metric
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
Vector = list[float]
class Collection(TypedDict):
dimension: int
distance: DistanceType
on_disk: bool
shards: int
DEFAULT_COLLECTIONS: dict[str, Collection] = {
"mail": {"dimension": 1536, "distance": "Cosine"},
# "chat": {"dimension": 1536, "distance": "Cosine"},
# "git": {"dimension": 1536, "distance": "Cosine"},
# "photo": {"dimension": 512, "distance": "Cosine"},
# "book": {"dimension": 1536, "distance": "Cosine"},
# "blog": {"dimension": 1536, "distance": "Cosine"},
# "doc": {"dimension": 1536, "distance": "Cosine"},
}
def get_qdrant_client() -> qdrant_client.QdrantClient: def get_qdrant_client() -> qdrant_client.QdrantClient:
"""Create and return a Qdrant client using environment configuration.""" """Create and return a Qdrant client using environment configuration."""

View File

@ -700,7 +700,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
vector_ids = vectorize_email(mail_message) vector_ids = vectorize_email(mail_message)
assert len(vector_ids) == 1 assert len(vector_ids) == 1
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001" assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4): def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
@ -741,6 +741,9 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
file_path=None, file_path=None,
) )
file_path = mail_message.attachments_path / "stored.txt"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_bytes(b"This is stored content")
attachment2 = EmailAttachment( attachment2 = EmailAttachment(
id=2, id=2,
mail_message_id=mail_message.id, mail_message_id=mail_message.id,
@ -748,21 +751,19 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
content_type="text/plain", content_type="text/plain",
size=200, size=200,
content=None, content=None,
file_path="/path/to/stored.txt", file_path=str(file_path),
) )
db_session.add_all([attachment1, attachment2]) db_session.add_all([attachment1, attachment2])
db_session.flush() db_session.flush()
# Mock embedding functions but use real qdrant # Mock embedding functions but use real qdrant
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \ with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
# Call the function # Call the function
vector_ids = vectorize_email(mail_message) vector_ids = vectorize_email(mail_message)
# Verify results # Verify results
assert len(vector_ids) == 3 assert len(vector_ids) == 3
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001" assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
assert vector_ids[1] == "00000000-0000-0000-0000-000000000002" assert vector_ids[1] == "doc/00000000-0000-0000-0000-000000000002"
assert vector_ids[2] == "00000000-0000-0000-0000-000000000003" assert vector_ids[2] == "doc/00000000-0000-0000-0000-000000000003"