mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 21:34:42 +02:00
better vectorization
This commit is contained in:
parent
a104a3211b
commit
3dca666d08
@ -1,3 +1,45 @@
|
|||||||
|
import pathlib
|
||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
|
# Distance metric names accepted for a collection.
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
# A single embedding: a flat list of floats.
Vector = list[float]


class _CollectionBase(TypedDict):
    """Keys every collection definition must provide."""

    dimension: int  # length of the vectors stored in the collection
    distance: DistanceType  # metric used to compare vectors


class Collection(_CollectionBase, total=False):
    """Configuration of one vector collection.

    ``dimension`` and ``distance`` are required.  ``on_disk`` and ``shards``
    are optional: the entries in ``DEFAULT_COLLECTIONS`` leave them out, so
    declaring them as required keys made those entries invalid for the type.
    """

    on_disk: bool
    shards: int


# Collections known out of the box, keyed by collection name.
DEFAULT_COLLECTIONS: dict[str, Collection] = {
    "mail": {"dimension": 1536, "distance": "Cosine"},
    "chat": {"dimension": 1536, "distance": "Cosine"},
    "git": {"dimension": 1536, "distance": "Cosine"},
    # NOTE(review): the only collection with a different vector size.
    "photo": {"dimension": 512, "distance": "Cosine"},
    "book": {"dimension": 1536, "distance": "Cosine"},
    "blog": {"dimension": 1536, "distance": "Cosine"},
    "doc": {"dimension": 1536, "distance": "Cosine"},
}
|
||||||
|
|
||||||
|
# Collection name -> MIME patterns routed to it.  A pattern is either an exact
# MIME type or "<type>/*", which matches every subtype of that top-level type.
TYPES: dict[str, list[str]] = {
    "doc": ["text/*"],
    "photo": ["image/*"],
    "book": [
        "application/pdf",
        "application/epub+zip",
        "application/mobi",
        "application/x-mobipocket-ebook",
    ],
}


def get_type(mime_type: str) -> str:
    """Return the collection name that handles *mime_type*.

    Exact entries in ``TYPES`` take priority; otherwise a ``<type>/*`` entry
    for the same top-level type is used.  The lookup ignores case and any
    parameters after ``;`` (e.g. ``"text/plain; charset=utf-8"``).

    Args:
        mime_type: A MIME type string such as ``"image/png"``.

    Returns:
        The matching key of ``TYPES``, or ``"unknown"`` when nothing matches.
    """
    normalized = mime_type.split(";", 1)[0].strip().lower()

    for collection, patterns in TYPES.items():
        if normalized in patterns:
            return collection

    # Only fall back when the table explicitly lists "<type>/*".  Comparing
    # the bare top-level type against the start of every pattern would send
    # every "application/..." type to "book" and an empty string to "doc".
    wildcard = f"{normalized.split('/')[0]}/*"
    for collection, patterns in TYPES.items():
        if wildcard in patterns:
            return collection

    return "unknown"
|
||||||
|
|
||||||
|
|
||||||
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
def embed_text(text: str, model: str = "text-embedding-3-small", n_dimensions: int = 1536) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Embed a text using OpenAI's API.
|
Embed a text using OpenAI's API.
|
||||||
@ -10,3 +52,9 @@ def embed_file(file_path: str, model: str = "text-embedding-3-small", n_dimensio
|
|||||||
Embed a file using OpenAI's API.
|
Embed a file using OpenAI's API.
|
||||||
"""
|
"""
|
||||||
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
return [0.0] * n_dimensions # Placeholder n_dimensions-dimensional vector
|
||||||
|
|
||||||
|
|
||||||
|
def embed(
    mime_type: str,
    content: bytes | str | pathlib.Path,
    model: str = "text-embedding-3-small",
    n_dimensions: int = 1536,
) -> tuple[str, list[float]]:
    """Choose the collection for *mime_type* and pair it with a vector.

    Placeholder implementation: ``content`` and ``model`` are accepted but not
    read, and the vector is ``n_dimensions`` zeros.

    Returns:
        ``(collection, vector)`` where ``collection`` is the result of
        ``get_type(mime_type)``.
    """
    # Placeholder n_dimensions-dimensional vector
    return get_type(mime_type), [0.0] * n_dimensions
|
@ -11,7 +11,7 @@ from email.utils import parsedate_to_datetime
|
|||||||
from typing import Generator, Callable, TypedDict, Literal
|
from typing import Generator, Callable, TypedDict, Literal
|
||||||
import pathlib
|
import pathlib
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from collections import defaultdict
|
||||||
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
from memory.common.db.models import EmailAccount, MailMessage, SourceItem, EmailAttachment
|
||||||
from memory.common import settings, embedding
|
from memory.common import settings, embedding
|
||||||
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
|
from memory.workers.qdrant import get_qdrant_client, upsert_vectors
|
||||||
@ -462,27 +462,35 @@ def imap_connection(account: EmailAccount) -> Generator[imaplib.IMAP4_SSL, None,
|
|||||||
def vectorize_email(email: MailMessage) -> list[str]:
    """Embed an email and its attachments and store the vectors.

    The message body is embedded with ``embedding.embed_text`` and upserted
    into the ``"mail"`` collection.  Each attachment is embedded with
    ``embedding.embed``, which also picks the collection from the attachment's
    content type; attachments are then upserted one collection at a time.

    Args:
        email: The message to vectorize; its ``attachments`` are processed too.

    Returns:
        One ``"<collection>/<uuid>"`` string per stored vector: the message
        body first, then the attachments grouped by collection.
    """
    qdrant_client = get_qdrant_client()

    vector_id = uuid.uuid4()
    upsert_vectors(
        client=qdrant_client,
        collection_name="mail",
        ids=[str(vector_id)],
        vectors=[embedding.embed_text(email.body_raw)],
        payloads=[email.as_payload()],
    )
    vector_ids = [f"mail/{vector_id}"]

    # collection name -> [(id, vector, payload), ...]
    by_collection = defaultdict(list)
    for attachment in email.attachments:
        if attachment.file_path:
            content = pathlib.Path(attachment.file_path).read_bytes()
        else:
            content = attachment.content
        collection, vector = embedding.embed(attachment.content_type, content)
        by_collection[collection].append(
            (str(uuid.uuid4()), vector, attachment.as_payload())
        )

    # NOTE(review): embedding.embed may return a collection that does not
    # exist in the store (e.g. "unknown") - confirm upsert_vectors copes.
    for collection, items in by_collection.items():
        # Transpose the triples into parallel lists, matching the list
        # arguments used for the "mail" upsert above.
        ids, vectors, payloads = (list(column) for column in zip(*items))
        upsert_vectors(
            client=qdrant_client,
            collection_name=collection,
            ids=ids,
            vectors=vectors,
            payloads=payloads,
        )
        vector_ids.extend(f"{collection}/{item_id}" for item_id in ids)

    logger.info(f"Stored embedding for message {email.message_id}")
    return vector_ids
|
||||||
|
@ -1,34 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal, TypedDict, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from qdrant_client.http import models as qdrant_models
|
from qdrant_client.http import models as qdrant_models
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
|
from memory.common.embedding import Collection, DEFAULT_COLLECTIONS, DistanceType, Vector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Type of distance metric
|
|
||||||
DistanceType = Literal["Cosine", "Dot", "Euclidean"]
|
|
||||||
Vector = list[float]
|
|
||||||
|
|
||||||
class Collection(TypedDict):
|
|
||||||
dimension: int
|
|
||||||
distance: DistanceType
|
|
||||||
on_disk: bool
|
|
||||||
shards: int
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_COLLECTIONS: dict[str, Collection] = {
|
|
||||||
"mail": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
# "chat": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
# "git": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
# "photo": {"dimension": 512, "distance": "Cosine"},
|
|
||||||
# "book": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
# "blog": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
# "doc": {"dimension": 1536, "distance": "Cosine"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_qdrant_client() -> qdrant_client.QdrantClient:
|
def get_qdrant_client() -> qdrant_client.QdrantClient:
|
||||||
"""Create and return a Qdrant client using environment configuration."""
|
"""Create and return a Qdrant client using environment configuration."""
|
||||||
|
@ -700,7 +700,7 @@ def test_vectorize_email_basic(db_session, qdrant, mock_uuid4):
|
|||||||
vector_ids = vectorize_email(mail_message)
|
vector_ids = vectorize_email(mail_message)
|
||||||
|
|
||||||
assert len(vector_ids) == 1
|
assert len(vector_ids) == 1
|
||||||
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
|
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
|
||||||
|
|
||||||
|
|
||||||
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
||||||
@ -741,6 +741,9 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
|||||||
file_path=None,
|
file_path=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file_path = mail_message.attachments_path / "stored.txt"
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
file_path.write_bytes(b"This is stored content")
|
||||||
attachment2 = EmailAttachment(
|
attachment2 = EmailAttachment(
|
||||||
id=2,
|
id=2,
|
||||||
mail_message_id=mail_message.id,
|
mail_message_id=mail_message.id,
|
||||||
@ -748,21 +751,19 @@ def test_vectorize_email_with_attachments(db_session, qdrant, mock_uuid4):
|
|||||||
content_type="text/plain",
|
content_type="text/plain",
|
||||||
size=200,
|
size=200,
|
||||||
content=None,
|
content=None,
|
||||||
file_path="/path/to/stored.txt",
|
file_path=str(file_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.add_all([attachment1, attachment2])
|
db_session.add_all([attachment1, attachment2])
|
||||||
db_session.flush()
|
db_session.flush()
|
||||||
|
|
||||||
# Mock embedding functions but use real qdrant
|
# Mock embedding functions but use real qdrant
|
||||||
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536), \
|
with patch("memory.common.embedding.embed_text", return_value=[0.1] * 1536):
|
||||||
patch("memory.common.embedding.embed_file", return_value=[0.7] * 1536):
|
|
||||||
|
|
||||||
# Call the function
|
# Call the function
|
||||||
vector_ids = vectorize_email(mail_message)
|
vector_ids = vectorize_email(mail_message)
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert len(vector_ids) == 3
|
assert len(vector_ids) == 3
|
||||||
assert vector_ids[0] == "00000000-0000-0000-0000-000000000001"
|
assert vector_ids[0] == "mail/00000000-0000-0000-0000-000000000001"
|
||||||
assert vector_ids[1] == "00000000-0000-0000-0000-000000000002"
|
assert vector_ids[1] == "doc/00000000-0000-0000-0000-000000000002"
|
||||||
assert vector_ids[2] == "00000000-0000-0000-0000-000000000003"
|
assert vector_ids[2] == "doc/00000000-0000-0000-0000-000000000003"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user