memory/tests/memory/common/test_embedding.py
2025-06-02 14:34:39 +02:00

55 lines
1.5 KiB
Python

from unittest.mock import Mock
import pytest
from memory.common import collections
from memory.common.embedding import (
embed_mixed,
embed_text,
)
from memory.common.extract import DataChunk
@pytest.fixture
def mock_embed(mock_voyage_client):
    """Stub the Voyage client's embedding calls with predictable vectors.

    Both ``embed`` and ``multimodal_embed`` on ``mock_voyage_client`` are
    replaced by the same function. It hands out the one-element vectors
    ``[0]``, ``[1]``, ... in order, one per input item, drawing from a single
    counter shared by both methods (at most 1000 vectors in total).

    Returns the patched ``mock_voyage_client``.
    """
    indices = iter(range(1000))

    def _fake_embed(texts, model, input_type):
        # One fresh vector per input; `model` and `input_type` are accepted
        # but ignored.
        embeddings = []
        for _ in texts:
            embeddings.append([next(indices)])
        return Mock(embeddings=embeddings)

    for method_name in ("embed", "multimodal_embed"):
        setattr(mock_voyage_client, method_name, _fake_embed)
    return mock_voyage_client
@pytest.mark.parametrize(
    "mime_type, expected_modality",
    [
        ("text/plain", "text"),
        ("text/html", "blog"),
        ("image/jpeg", "photo"),
        ("image/png", "photo"),
        ("application/pdf", "doc"),
        ("application/epub+zip", "book"),
        ("application/mobi", "book"),
        ("application/x-mobipocket-ebook", "book"),
        ("audio/mp3", "unknown"),
        ("video/mp4", "unknown"),
        ("text/something-new", "text"),  # Should match by 'text/' stem
        ("image/something-new", "photo"),  # Should match by 'image/' stem
        ("custom/format", "unknown"),  # No matching stem
    ],
)
def test_get_modality(mime_type, expected_modality):
    """Check that ``get_modality`` maps each MIME type to its expected modality."""
    modality = collections.get_modality(mime_type)
    assert modality == expected_modality
def test_embed_text(mock_embed):
    """``embed_text`` yields one stubbed vector per single-text chunk, in order."""
    chunks = [DataChunk(data=[text]) for text in ("text1 with words", "text2")]
    vectors = embed_text(chunks)
    # The stubbed client numbers its vectors from 0, so two chunks give [0], [1].
    assert vectors == [[0], [1]]
def test_embed_mixed(mock_embed):
    """``embed_mixed`` returns the first stubbed vector for a single chunk."""
    chunk = DataChunk(data=["text"])
    result = embed_mixed([chunk])
    assert result == [[0]]