simplify embedding

This commit is contained in:
Daniel O'Connell 2025-05-26 02:02:50 +02:00
parent 9f1632555b
commit a5618f3543
5 changed files with 116 additions and 501 deletions

View File

@ -8,6 +8,7 @@ import re
import textwrap
from datetime import datetime
from typing import Any, ClassVar, Iterable, Sequence, cast
import uuid
from PIL import Image
from sqlalchemy import (
@ -88,6 +89,19 @@ def clean_filename(filename: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
def image_filenames(chunk_id: str, images: "list[Image.Image]") -> list[str]:
    """Return a file name for every image, saving unnamed images first.

    An image that already carries a non-empty ``filename`` is left untouched.
    Any other image is written to ``{chunk_id}_{index}.{format}`` and has its
    ``filename`` attribute set to that name, so the returned list always holds
    one usable name per image, in input order.

    Args:
        chunk_id: Prefix for generated file names.
        images: Images to name; unnamed ones are saved as a side effect.

    Returns:
        The file name of each image, in the same order as ``images``.
    """
    # NOTE(review): generated names carry no directory component, so files
    # land wherever the process's working directory is - confirm intended.
    names: list[str] = []
    for i, image in enumerate(images):
        # Pillow only sets `filename` on images loaded from a file, so the
        # attribute may be missing entirely on in-memory images.
        filename = getattr(image, "filename", None)
        if not filename:
            # In-memory images have no format either; fall back to PNG so the
            # extension is one that `save` can resolve to an encoder.
            extension = getattr(image, "format", None) or "PNG"
            filename = f"{chunk_id}_{i}.{extension}"
            image.save(filename)
            # Record the new name on the image; otherwise the caller would
            # get back the empty/missing name it started with.
            image.filename = filename  # type: ignore
        names.append(str(filename))
    return names
def add_pics(
    chunk: str, images: "list[Image.Image]"
) -> "list[extract.MulitmodalChunk]":
    """Return the text chunk followed by the images it mentions by file name.

    Args:
        chunk: The text to search for image file names.
        images: Candidate images; only those whose ``filename`` occurs in
            ``chunk`` are kept.

    Returns:
        A list starting with ``chunk``, followed by the matching images in
        their original order.
    """
    referenced = []
    for image in images:
        name = getattr(image, "filename", None)
        # An empty name is a substring of every string, so unnamed images
        # have to be skipped explicitly or they would match every chunk.
        if name and str(name) in chunk:
            referenced.append(image)
    return [chunk, *referenced]  # type: ignore
class Chunk(Base):
"""Stores content chunks with their vector embeddings."""
@ -107,6 +121,7 @@ class Chunk(Base):
checked_at = Column(DateTime(timezone=True), server_default=func.now())
vector: ClassVar[list[float] | None] = None
item_metadata: ClassVar[dict[str, Any] | None] = None
images: list[Image.Image] = []
# One of file_path or content must be populated
__table_args__ = (
@ -119,11 +134,8 @@ class Chunk(Base):
if self.file_path is None:
return [cast(str, self.content)]
path = pathlib.Path(self.file_path.replace("/app/", ""))
if cast(str, self.file_path).endswith("*"):
files = list(path.parent.glob(path.name))
else:
files = [path]
paths = [pathlib.Path(p) for p in self.file_path.split("\n")]
files = [path for path in paths if path.exists()]
items = []
for file_path in files:
@ -179,13 +191,37 @@ class SourceItem(Base):
"""Get vector IDs from associated chunks."""
return [chunk.id for chunk in self.chunks]
def data_chunks(self) -> Iterable[extract.DataChunk]:
return [
extract.DataChunk(
data=cast(str, self.content),
collection=cast(str, self.modality),
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
    """Split this item into groups of content, one group per future chunk.

    Text content is cut up with ``chunker.chunk_text`` and each piece becomes
    its own single-element group. When the MIME type starts with ``image/``,
    the file at ``self.filename`` is opened and appended as one more group.
    """
    groups: list[list[extract.MulitmodalChunk]] = []

    if cast(str | None, self.content):
        for piece in chunker.chunk_text(cast(str, self.content)):
            groups.append([piece])

    mime = cast(str | None, self.mime_type) or ""
    if mime.startswith("image/"):
        picture = Image.open(self.filename)
        groups.append([picture])

    return groups
# Build a single Chunk for this item from one group of strings and images.
#
# `data` is the group to store: its strings become the chunk text and its
# images are attached and referenced through `file_path`.
# `metadata` is merged over `self.as_payload()`; on a key clash the value from
# `metadata` wins (right-hand operand of `|`). The shared `{}` default is never
# mutated here because `|` builds a new dict.
# Returns the newly constructed Chunk (no vector is set in this function).
def _make_chunk(
self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
):
# A fresh random id; also used as the prefix for generated image file names.
chunk_id = str(uuid.uuid4())
# All string parts, separated by a blank line; "" when there are none.
text = "\n\n".join(c for c in data if isinstance(c, str))
images = [c for c in data if isinstance(c, Image.Image)]
# One file name per image, as reported by image_filenames().
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
id=chunk_id,
source=self,
content=text,
images=images,
# Newline-separated image names, or None when the group has no images.
file_path="\n".join(image_names) if image_names else None,
# The model is looked up from this item's modality.
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=self.as_payload() | metadata,
)
# NOTE(review): the lone `]` below looks like a leftover from the previous
# list-returning data_chunks implementation - confirm before relying on it.
]
return chunk
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
    """Build one Chunk per content group returned by ``_chunk_contents``.

    ``metadata`` is handed unchanged to ``_make_chunk`` for every group.
    """
    built = []
    for group in self._chunk_contents():
        built.append(self._make_chunk(group, metadata))
    return built
def as_payload(self) -> dict:
return {
@ -319,18 +355,14 @@ class EmailAttachment(SourceItem):
"tags": self.tags,
}
def data_chunks(self) -> Iterable[extract.DataChunk]:
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
if cast(str | None, self.filename):
contents = pathlib.Path(cast(str, self.filename)).read_bytes()
else:
contents = cast(str, self.content)
return extract.extract_data_chunks(
cast(str, self.mime_type),
contents,
collection=cast(str, self.modality),
embedding_model=collections.collection_model(cast(str, self.modality)),
)
chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks]
# Add indexes
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
@ -433,14 +465,10 @@ class Comic(SourceItem):
}
return {k: v for k, v in payload.items() if v is not None}
def data_chunks(self) -> Iterable[extract.DataChunk]:
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
image = Image.open(pathlib.Path(cast(str, self.filename)))
return [
extract.DataChunk(
data=[image, cast(str, self.title), cast(str, self.author)],
collection=cast(str, self.modality),
)
]
description = f"{self.title} by {self.author}"
return [[image, description]]
class Book(Base):
@ -538,16 +566,11 @@ class BookSection(SourceItem):
"tags": self.tags,
}
def data_chunks(self) -> Iterable[extract.DataChunk]:
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
texts += [(cast(str, self.content), self.start_page)]
return [
extract.DataChunk(
data=[text],
collection=cast(str, self.modality),
metadata={"page": page_number},
max_size=chunker.EMBEDDING_MAX_TOKENS,
)
self._make_chunk([text.strip()], metadata | {"page": page_number})
for text, page_number in texts
]
@ -601,29 +624,18 @@ class BlogPost(SourceItem):
}
return {k: v for k, v in payload.items() if v}
def data_chunks(self) -> Iterable[extract.DataChunk]:
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
images = [Image.open(image) for image in self.images]
data = [cast(str, self.content)] + images
# Always embed the full content as a single chunk (if possible, of course)
chunks = [
extract.DataChunk(
data=data,
collection=cast(str, self.modality),
max_size=chunker.EMBEDDING_MAX_TOKENS,
)
]
content = cast(str, self.content)
full_text = [content, *images]
# If the content is long enough, also embed it as chunks of the default size.
tokens = chunker.approx_token_count(cast(str, self.content))
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks += [
extract.DataChunk(
data=data,
collection=cast(str, self.modality),
)
]
return chunks
if tokens < chunker.DEFAULT_CHUNK_TOKENS * 2:
return [full_text]
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
return [full_text] + chunks
class MiscDoc(SourceItem):

View File

@ -1,17 +1,16 @@
from collections.abc import Sequence
import logging
import pathlib
import uuid
from typing import Any, Iterable, Literal, cast
from typing import Iterable, Literal, cast
import voyageai
from PIL import Image
from memory.common import extract, settings
from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
from memory.common.collections import ALL_COLLECTIONS, Vector
from memory.common.chunker import (
DEFAULT_CHUNK_TOKENS,
OVERLAP_TOKENS,
chunk_text,
)
from memory.common.collections import Vector
from memory.common.db.models import Chunk, SourceItem
from memory.common.extract import DataChunk
logger = logging.getLogger(__name__)
@ -72,91 +71,18 @@ def embed_mixed(
return embed_chunks([chunks], model, input_type)
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
    """Store one item under ``settings.CHUNK_STORAGE_DIR`` and return its path.

    Strings go to ``<chunk_id>.txt``, bytes to ``<chunk_id>.bin`` and images
    to ``<chunk_id>.png``.

    Raises:
        ValueError: If ``item`` is none of the supported types.
    """
    if isinstance(item, str):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
        target.write_text(item)
        return target

    if isinstance(item, bytes):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
        target.write_bytes(item)
        return target

    if isinstance(item, Image.Image):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
        item.save(target)
        return target

    raise ValueError(f"Unsupported content type: {type(item)}")
def make_chunk(
    contents: Sequence[extract.MulitmodalChunk],
    vector: Vector,
    metadata: dict[str, Any] = {},
) -> Chunk:
    """Wrap some content and its vector in a Chunk.

    Three shapes of ``contents`` are handled:

    * only strings: joined with blank lines and kept inline as ``content``,
      using the text embedding model;
    * exactly one non-text item: written to a file whose absolute path is
      stored in ``file_path``, using the mixed embedding model;
    * anything else: every item is written to ``<chunk_id>_<index>.*`` and
      ``file_path`` holds the matching ``<chunk_id>_*`` glob pattern, using
      the mixed embedding model.
    """
    chunk_id = str(uuid.uuid4())
    text: str | None = None
    stored_at: str | None = None

    if all(isinstance(part, str) for part in contents):
        text = "\n\n".join(cast(list[str], contents))
        model = settings.TEXT_EMBEDDING_MODEL
    else:
        if len(contents) == 1:
            stored_at = write_to_file(chunk_id, contents[0]).absolute().as_posix()
        else:
            for index, part in enumerate(contents):
                write_to_file(f"{chunk_id}_{index}", part)
            pattern = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_*"
            stored_at = pattern.absolute().as_posix()
        model = settings.MIXED_EMBEDDING_MODEL

    return Chunk(
        id=chunk_id,
        file_path=stored_at,
        content=text,
        embedding_model=model,
        vector=vector,
        item_metadata=metadata,
    )
def embed_data_chunk(
chunk: DataChunk,
metadata: dict[str, Any] = {},
chunk_size: int | None = None,
) -> list[Chunk]:
chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS
model = chunk.embedding_model
if not model and chunk.collection:
model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model")
if not model:
model = settings.TEXT_EMBEDDING_MODEL
def embed_chunk(chunk: Chunk) -> Chunk:
model = cast(str, chunk.embedding_model)
if model == settings.TEXT_EMBEDDING_MODEL:
vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size)
content = cast(str, chunk.content)
elif model == settings.MIXED_EMBEDDING_MODEL:
vectors = embed_mixed(
cast(list[extract.MulitmodalChunk], chunk.data),
chunk_size=chunk_size,
)
content = [cast(str, chunk.content)] + chunk.images
else:
raise ValueError(f"Unsupported model: {model}")
metadata = metadata | chunk.metadata
return [make_chunk(chunk.data, vector, metadata) for vector in vectors]
raise ValueError(f"Unsupported model: {chunk.embedding_model}")
vectors = embed_chunks([content], model) # type: ignore
chunk.vector = vectors[0] # type: ignore
return chunk
def embed_source_item(
    item: SourceItem,
    metadata: dict[str, Any] = {},
    chunk_size: int | None = None,
) -> list[Chunk]:
    """Embed every data chunk of ``item`` and return the chunks in one flat list.

    Each data chunk is embedded with the item's payload merged with
    ``metadata`` (keys in ``metadata`` win) and with the given ``chunk_size``.
    """
    embedded: list[Chunk] = []
    for data_chunk in item.data_chunks():
        payload = item.as_payload() | metadata
        embedded.extend(embed_data_chunk(data_chunk, payload, chunk_size))
    return embedded
def embed_source_item(item: SourceItem) -> list[Chunk]:
    """Run ``embed_chunk`` over every chunk the item produces, in order."""
    embedded = []
    for chunk in item.data_chunks():
        embedded.append(embed_chunk(chunk))
    return embedded

View File

@ -4,8 +4,9 @@ import logging
import pathlib
import tempfile
from contextlib import contextmanager
from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast
from typing import Any, Generator, Sequence, cast
from memory.common import chunker
import pymupdf # PyMuPDF
import pypandoc
from PIL import Image
@ -15,19 +16,9 @@ logger = logging.getLogger(__name__)
MulitmodalChunk = Image.Image | str
# Dictionary shape for one extracted page: its content plus metadata.
class Page(TypedDict):
# The strings and/or images that make up the page.
contents: Sequence[MulitmodalChunk]
# Free-form metadata for the page (e.g. page number, width, height).
metadata: dict[str, Any]
# This is used to override the default chunk size for the page
chunk_size: NotRequired[int]
# A group of strings and/or images to embed together, plus metadata.
@dataclass
class DataChunk:
# The content itself; each element is a string or a PIL image.
data: Sequence[MulitmodalChunk]
# Optional collection name; embed_data_chunk uses it to look up a model
# when no explicit embedding_model is given.
collection: str | None = None
# Optional explicit model name; checked before the collection lookup.
embedding_model: str | None = None
# Optional chunk size; embed_data_chunk prefers it over its own argument.
max_size: int | None = None
# Extra metadata; each instance gets its own fresh dict by default.
metadata: dict[str, Any] = field(default_factory=dict)
@ -48,18 +39,18 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
with as_file(content) as file_path:
with pymupdf.open(file_path) as pdf:
return [
{
"contents": [page_to_image(page)],
"metadata": {
DataChunk(
data=[page_to_image(page)],
metadata={
"page": page.number,
"width": page.rect.width,
"height": page.rect.height,
},
}
)
for page in pdf.pages()
]
@ -93,7 +84,7 @@ def docx_to_pdf(
raise
def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[DataChunk]:
"""Extract content from DOCX by converting to PDF first, then processing"""
with as_file(docx_path) as file_path:
pdf_path = docx_to_pdf(file_path)
@ -101,57 +92,47 @@ def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
return doc_to_images(pdf_path)
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]:
if isinstance(content, pathlib.Path):
image = Image.open(content)
elif isinstance(content, bytes):
image = Image.open(io.BytesIO(content))
else:
raise ValueError(f"Unsupported content type: {type(content)}")
return [{"contents": [image], "metadata": {}}]
return [DataChunk(data=[image])]
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
def extract_text(
content: bytes | str | pathlib.Path, chunk_size: int | None = None
) -> list[DataChunk]:
if isinstance(content, pathlib.Path):
content = content.read_text()
if isinstance(content, bytes):
content = content.decode("utf-8")
return [{"contents": [cast(str, content)], "metadata": {}}]
content = cast(str, content)
chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
return [DataChunk(data=[c]) for c in chunks]
def extract_data_chunks(
mime_type: str,
content: bytes | str | pathlib.Path,
collection: str | None = None,
embedding_model: str | None = None,
chunk_size: int | None = None,
) -> list[DataChunk]:
pages = []
chunks = []
logger.info(f"Extracting content from {mime_type}")
if mime_type == "application/pdf":
pages = doc_to_images(content)
chunks = doc_to_images(content)
elif mime_type in [
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/msword",
]:
logger.info(f"Extracting content from {content}")
pages = extract_docx(content)
logger.info(f"Extracted {len(pages)} pages from {content}")
chunks = extract_docx(content)
logger.info(f"Extracted {len(chunks)} pages from {content}")
elif mime_type.startswith("text/"):
pages = extract_text(content)
chunks = extract_text(content, chunk_size)
elif mime_type.startswith("image/"):
pages = extract_image(content)
if chunk_size:
pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
return [
DataChunk(
data=page["contents"],
collection=collection,
embedding_model=embedding_model,
max_size=chunk_size,
)
for page in pages
]
chunks = extract_image(content)
return chunks

View File

@ -1,16 +1,10 @@
import pathlib
import uuid
from unittest.mock import Mock, patch
from typing import cast
from unittest.mock import Mock
import pytest
from PIL import Image
from memory.common import settings, collections
from memory.common import collections
from memory.common.embedding import (
embed_mixed,
embed_text,
make_chunk,
write_to_file,
)
@ -57,302 +51,3 @@ def test_embed_text(mock_embed):
def test_embed_mixed(mock_embed):
    """Mixed text/image input yields the single vector from the mocked embedder."""
    embedded = embed_mixed(["text", {"type": "image", "data": "base64"}])
    assert embedded == [[0]]
def test_write_to_file_text(mock_file_storage):
    """A string is stored as ``<chunk_id>.txt`` with its text intact."""
    chunk_id = "test-chunk-id"
    text = "This is a test string"

    written = write_to_file(chunk_id, text)

    expected = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
    assert written == expected
    assert written.exists()
    assert written.read_text() == text
def test_write_to_file_bytes(mock_file_storage):
    """Bytes are stored verbatim as ``<chunk_id>.bin``."""
    chunk_id = "test-chunk-id"
    payload = b"These are test bytes"

    written = write_to_file(chunk_id, payload)  # type: ignore

    expected = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
    assert written == expected
    assert written.exists()
    assert written.read_bytes() == payload
def test_write_to_file_image(mock_file_storage):
    """Test writing an image to a file."""
    img = Image.new("RGB", (100, 100), color=(73, 109, 137))
    chunk_id = "test-chunk-id"

    file_path = write_to_file(chunk_id, img)  # type: ignore

    assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
    assert file_path.exists()

    # Verify it's a valid image file by opening it. Image.open keeps the file
    # handle open until the image is closed, so use it as a context manager.
    with Image.open(file_path) as image:
        assert image.size == (100, 100)
def test_write_to_file_unsupported_type(mock_file_storage):
    """An item that is not str, bytes or an image is rejected with ValueError."""
    unsupported = 123  # Integer is not a supported type

    with pytest.raises(ValueError, match="Unsupported content type"):
        write_to_file("test-chunk-id", unsupported)  # type: ignore
def test_make_chunk_text_only(mock_file_storage, db_session):
    """Text-only contents are joined and kept inline, with no file written."""
    fixed_id = uuid.UUID("00000000-0000-0000-0000-000000000001")
    parts = ["text content 1", "text content 2"]
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}

    with patch.object(uuid, "uuid4", return_value=fixed_id):
        chunk = make_chunk(parts, embedding, extra)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001"
    assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
    assert chunk.file_path is None
    assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra
def test_make_chunk_single_image(mock_file_storage, db_session):
    """A lone image is written to ``<chunk_id>.png`` and referenced by path."""
    fixed_id = uuid.UUID("00000000-0000-0000-0000-000000000002")
    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}

    with patch.object(uuid, "uuid4", return_value=fixed_id):
        chunk = make_chunk([picture], embedding, extra)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002"
    assert chunk.content is None
    assert cast(str, chunk.file_path) == str(
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png",
    )
    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra

    # The image must actually be on disk
    stored = pathlib.Path(cast(str, chunk.file_path))
    assert stored.exists()
def test_make_chunk_mixed_content(mock_file_storage, db_session):
    """Text plus image is stored as numbered files referenced by a glob path."""
    fixed_id = uuid.UUID("00000000-0000-0000-0000-000000000003")
    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}

    with patch.object(uuid, "uuid4", return_value=fixed_id):
        chunk = make_chunk(["text content", picture], embedding, extra)  # type: ignore

    assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003"
    assert chunk.content is None
    assert cast(str, chunk.file_path) == str(
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*",
    )
    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra

    # One file per item, numbered by position
    text_file = settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_0.txt"
    assert text_file.exists()
    image_file = settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
    assert image_file.exists()
@pytest.mark.parametrize(
    "data,embedding_model,collection,expected_model,expected_count,expected_has_content",
    [
        # Text-only with default model
        (
            ["text content 1", "text content 2"],
            None,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            2,
            True,
        ),
        # Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content
        (
            ["text content"],
            settings.MIXED_EMBEDDING_MODEL,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
        # Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content
        (["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Unknown collection falls back to default
        (["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Explicit model takes precedence over collection
        (
            ["text content"],
            settings.TEXT_EMBEDDING_MODEL,
            "photo",
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
    ],
)
def test_embed_data_chunk_scenarios(
    data,
    embedding_model,
    collection,
    expected_model,
    expected_count,
    expected_has_content,
    mock_embed,
    mock_file_storage,
):
    """Check model choice, chunk count, storage and metadata per scenario."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    data_chunk = DataChunk(
        data=data,
        embedding_model=embedding_model,
        collection=collection,
        metadata={"source": "test"},
    )

    embedded = embed_data_chunk(data_chunk, {"doc_type": "test"})

    assert len(embedded) == expected_count
    for made in embedded:
        assert cast(str, made.embedding_model) == expected_model

    # Inline content and file storage are mutually exclusive
    for made in embedded:
        if expected_has_content:
            assert made.content is not None
        else:
            assert made.content is None
    for made in embedded:
        if expected_has_content:
            assert made.file_path is None
        else:
            assert made.file_path is not None

    for made in embedded:
        assert made.item_metadata == {"source": "test", "doc_type": "test"}
def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage):
    """Text plus image is embedded as a single, file-backed chunk."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    data_chunk = DataChunk(
        data=["text content", picture],
        embedding_model=settings.MIXED_EMBEDDING_MODEL,
        metadata={"source": "test"},
    )

    embedded = embed_data_chunk(data_chunk)

    assert len(embedded) == 1  # Mixed content returns single vector
    first = embedded[0]
    assert first.content is None  # Mixed content stored in files
    assert first.file_path is not None
    assert cast(str, first.embedding_model) == settings.MIXED_EMBEDDING_MODEL
@pytest.mark.parametrize(
    "chunk_max_size,chunk_size_param,expected_chunk_size",
    [
        (512, 1024, 512),  # chunk.max_size takes precedence
        (None, 2048, 2048),  # chunk_size parameter used when max_size is None
        (256, None, 256),  # chunk.max_size used when parameter is None
    ],
)
def test_embed_data_chunk_chunk_size_handling(
    chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage
):
    """The chunk size passed on to ``embed_text`` follows the precedence rules."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    data_chunk = DataChunk(
        data=["text content"], max_size=chunk_max_size, metadata={"source": "test"}
    )

    with patch(
        "memory.common.embedding.embed_text", return_value=[[0.1, 0.2, 0.3]]
    ) as mock_embed_text:
        embed_data_chunk(data_chunk, chunk_size=chunk_size_param)

    mock_embed_text.assert_called_once()
    _, kwargs = mock_embed_text.call_args
    assert kwargs["chunk_size"] == expected_chunk_size
def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage):
    """Chunk metadata is merged over the metadata passed as an argument."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    data_chunk = DataChunk(
        data=["text content"], metadata={"source": "test", "type": "chunk"}
    )
    # "source" appears on both sides; the chunk's own value is expected to win
    passed_in = {"doc_type": "test", "source": "override"}

    embedded = embed_data_chunk(data_chunk, passed_in)

    assert len(embedded) == 1
    assert embedded[0].item_metadata == {
        "source": "test",
        "type": "chunk",
        "doc_type": "test",
    }
def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage):
    """An unknown embedding model name is rejected with ValueError."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    bad_chunk = DataChunk(
        data=["text content"],
        embedding_model="unsupported-model",
        metadata={"source": "test"},
    )

    expected_error = pytest.raises(
        ValueError, match="Unsupported model: unsupported-model"
    )
    with expected_error:
        embed_data_chunk(bad_chunk)
def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage):
    """A chunk with no data produces no embedded chunks."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    empty_chunk = DataChunk(data=[], metadata={"source": "test"})

    # With no vectors coming back from the embedder, nothing should be built
    with patch("memory.common.embedding.embed_text", return_value=[]):
        assert embed_data_chunk(empty_chunk) == []

View File

@ -11,6 +11,7 @@ from memory.common.extract import (
extract_image,
docx_to_pdf,
extract_docx,
DataChunk,
)
@ -47,8 +48,8 @@ def test_as_file_with_str():
@pytest.mark.parametrize(
"input_content,expected",
[
("simple text", [{"contents": ["simple text"], "metadata": {}}]),
(b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
("simple text", [DataChunk(data=["simple text"], metadata={})]),
(b"bytes text", [DataChunk(data=["bytes text"], metadata={})]),
],
)
def test_extract_text(input_content, expected):
@ -60,7 +61,7 @@ def test_extract_text_with_path(tmp_path):
test_file.write_text("file text content")
assert extract_text(test_file) == [
{"contents": ["file text content"], "metadata": {}}
DataChunk(data=["file text content"], metadata={})
]
@ -72,8 +73,8 @@ def test_doc_to_images():
for page, pdf_page in zip(result, pdf.pages()):
pix = pdf_page.get_pixmap()
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
assert page["contents"] == [img]
assert page["metadata"] == {
assert page.data == [img]
assert page.metadata == {
"page": pdf_page.number,
"width": pdf_page.rect.width,
"height": pdf_page.rect.height,
@ -86,8 +87,8 @@ def test_extract_image_with_path(tmp_path):
img.save(img_path)
(page,) = extract_image(img_path)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {}
assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page.metadata == {}
def test_extract_image_with_bytes():
@ -97,8 +98,8 @@ def test_extract_image_with_bytes():
img_bytes = buffer.getvalue()
(page,) = extract_image(img_bytes)
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page["metadata"] == {}
assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore
assert page.metadata == {}
def test_extract_image_with_str():