From a5618f3543150d0b6ecca4074264f75fd9c9cc06 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 26 May 2025 02:02:50 +0200 Subject: [PATCH] simplify embedding --- src/memory/common/db/models.py | 116 +++++----- src/memory/common/embedding.py | 108 ++------- src/memory/common/extract.py | 65 ++---- tests/memory/common/test_embedding.py | 309 +------------------------- tests/memory/common/test_extract.py | 19 +- 5 files changed, 116 insertions(+), 501 deletions(-) diff --git a/src/memory/common/db/models.py b/src/memory/common/db/models.py index b7cca39..00851f2 100644 --- a/src/memory/common/db/models.py +++ b/src/memory/common/db/models.py @@ -8,6 +8,7 @@ import re import textwrap from datetime import datetime from typing import Any, ClassVar, Iterable, Sequence, cast +import uuid from PIL import Image from sqlalchemy import ( @@ -88,6 +89,19 @@ def clean_filename(filename: str) -> str: return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_") +def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]: + for i, image in enumerate(images): + if not image.filename: # type: ignore + filename = f"{chunk_id}_{i}.{image.format}" # type: ignore + image.save(filename) + + return [image.filename for image in images] # type: ignore + + +def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]: + return [chunk] + [i for i in images if i.filename in chunk] # type: ignore + + class Chunk(Base): """Stores content chunks with their vector embeddings.""" @@ -107,6 +121,7 @@ class Chunk(Base): checked_at = Column(DateTime(timezone=True), server_default=func.now()) vector: ClassVar[list[float] | None] = None item_metadata: ClassVar[dict[str, Any] | None] = None + images: list[Image.Image] = [] # One of file_path or content must be populated __table_args__ = ( @@ -119,11 +134,8 @@ class Chunk(Base): if self.file_path is None: return [cast(str, self.content)] - path = pathlib.Path(self.file_path.replace("/app/", "")) - if cast(str, self.file_path).endswith("*"): - files = list(path.parent.glob(path.name)) - else: - files = [path] + paths = [pathlib.Path(p) for p in self.file_path.split("\n")] + files = [path for path in paths if path.exists()] items = [] for file_path in files: @@ -179,13 +191,37 @@ class SourceItem(Base): """Get vector IDs from associated chunks.""" return [chunk.id for chunk in self.chunks] - def data_chunks(self) -> Iterable[extract.DataChunk]: - return [ - extract.DataChunk( - data=cast(str, self.content), - collection=cast(str, self.modality), - ) - ] + def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]: + chunks: list[list[extract.MulitmodalChunk]] = [] + if cast(str | None, self.content): + chunks = [[c] for c in chunker.chunk_text(cast(str, self.content))] + + mime_type = cast(str | None, self.mime_type) + if mime_type and mime_type.startswith("image/"): + chunks.append([Image.open(self.filename)]) + return chunks + + def _make_chunk( + self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {} + ): + chunk_id = str(uuid.uuid4()) + text = "\n\n".join(c for c in data if isinstance(c, str)) + images = [c for c in data if isinstance(c, Image.Image)] + image_names = image_filenames(chunk_id, images) + + chunk = Chunk( + id=chunk_id, + source=self, + content=text, + images=images, + file_path="\n".join(image_names) if image_names else None, + embedding_model=collections.collection_model(cast(str, self.modality)), + item_metadata=self.as_payload() | metadata, + ) + return chunk + + def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: + return [self._make_chunk(data, metadata) for data in self._chunk_contents()] def as_payload(self) -> dict: return { @@ -319,18 +355,14 @@ class EmailAttachment(SourceItem): "tags": self.tags, } - def data_chunks(self) -> Iterable[extract.DataChunk]: + def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: if cast(str | None, self.filename): contents = pathlib.Path(cast(str, self.filename)).read_bytes() else: contents = cast(str, self.content) - return extract.extract_data_chunks( - cast(str, self.mime_type), - contents, - collection=cast(str, self.modality), - embedding_model=collections.collection_model(cast(str, self.modality)), - ) + chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents) + return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks] # Add indexes __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),) @@ -433,14 +465,10 @@ class Comic(SourceItem): } return {k: v for k, v in payload.items() if v is not None} - def data_chunks(self) -> Iterable[extract.DataChunk]: + def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]: image = Image.open(pathlib.Path(cast(str, self.filename))) - return [ - extract.DataChunk( - data=[image, cast(str, self.title), cast(str, self.author)], - collection=cast(str, self.modality), - ) - ] + description = f"{self.title} by {self.author}" + return [[image, description]] class Book(Base): @@ -538,16 +566,11 @@ class BookSection(SourceItem): "tags": self.tags, } - def data_chunks(self) -> Iterable[extract.DataChunk]: + def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)] texts += [(cast(str, self.content), self.start_page)] return [ - extract.DataChunk( - data=[text], - collection=cast(str, self.modality), - metadata={"page": page_number}, - max_size=chunker.EMBEDDING_MAX_TOKENS, - ) + self._make_chunk([text.strip()], metadata | {"page": page_number}) for text, page_number in texts ] @@ -601,29 +624,18 @@ class BlogPost(SourceItem): } return {k: v for k, v in payload.items() if v} - def data_chunks(self) -> Iterable[extract.DataChunk]: + def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]: images = [Image.open(image) for image in self.images] - data = [cast(str, self.content)] + images - # Always embed the full content as a single chunk (if possible, of course) - chunks = [ - extract.DataChunk( - data=data, - collection=cast(str, self.modality), - max_size=chunker.EMBEDDING_MAX_TOKENS, - ) - ] + content = cast(str, self.content) + full_text = [content, *images] - # If the content is long enough, also embed it as chunks of the default size. tokens = chunker.approx_token_count(cast(str, self.content)) - if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2: - chunks += [ - extract.DataChunk( - data=data, - collection=cast(str, self.modality), - ) - ] - return chunks + if tokens < chunker.DEFAULT_CHUNK_TOKENS * 2: + return [full_text] + + chunks = [add_pics(c, images) for c in chunker.chunk_text(content)] + return [full_text] + chunks class MiscDoc(SourceItem): diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py index 6eb26d0..95acbc4 100644 --- a/src/memory/common/embedding.py +++ b/src/memory/common/embedding.py @@ -1,17 +1,16 @@ -from collections.abc import Sequence import logging -import pathlib -import uuid -from typing import Any, Iterable, Literal, cast +from typing import Iterable, Literal, cast import voyageai -from PIL import Image from memory.common import extract, settings -from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS -from memory.common.collections import ALL_COLLECTIONS, Vector +from memory.common.chunker import ( + DEFAULT_CHUNK_TOKENS, + OVERLAP_TOKENS, + chunk_text, +) +from memory.common.collections import Vector from memory.common.db.models import Chunk, SourceItem -from memory.common.extract import DataChunk logger = logging.getLogger(__name__) @@ -72,91 +71,18 @@ def embed_mixed( return embed_chunks([chunks], model, input_type) -def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path: - if isinstance(item, str): - filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt" - filename.write_text(item) - elif isinstance(item, bytes): - filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin" - filename.write_bytes(item) - elif isinstance(item, Image.Image): - filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png" - item.save(filename) - else: - raise ValueError(f"Unsupported content type: {type(item)}") - return filename - - -def make_chunk( - contents: Sequence[extract.MulitmodalChunk], - vector: Vector, - metadata: dict[str, Any] = {}, -) -> Chunk: - """Create a Chunk object from a page and a vector. - - This is quite complex, because we need to handle the case where the page is a single string, - a single image, or a list of strings and images. - """ - chunk_id = str(uuid.uuid4()) - content, filename = None, None - if all(isinstance(c, str) for c in contents): - content = "\n\n".join(cast(list[str], contents)) - model = settings.TEXT_EMBEDDING_MODEL - elif len(contents) == 1: - filename = write_to_file(chunk_id, contents[0]).absolute().as_posix() - model = settings.MIXED_EMBEDDING_MODEL - else: - for i, item in enumerate(contents): - write_to_file(f"{chunk_id}_{i}", item) - model = settings.MIXED_EMBEDDING_MODEL - filename = (settings.CHUNK_STORAGE_DIR / f"{chunk_id}_*").absolute().as_posix() - - return Chunk( - id=chunk_id, - file_path=filename, - content=content, - embedding_model=model, - vector=vector, - item_metadata=metadata, - ) - - -def embed_data_chunk( - chunk: DataChunk, - metadata: dict[str, Any] = {}, - chunk_size: int | None = None, -) -> list[Chunk]: - chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS - - model = chunk.embedding_model - if not model and chunk.collection: - model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model") - if not model: - model = settings.TEXT_EMBEDDING_MODEL - +def embed_chunk(chunk: Chunk) -> Chunk: + model = cast(str, chunk.embedding_model) if model == settings.TEXT_EMBEDDING_MODEL: - vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size) + content = cast(str, chunk.content) elif model == settings.MIXED_EMBEDDING_MODEL: - vectors = embed_mixed( - cast(list[extract.MulitmodalChunk], chunk.data), - chunk_size=chunk_size, - ) + content = [cast(str, chunk.content)] + chunk.images else: - raise ValueError(f"Unsupported model: {model}") - - metadata = metadata | chunk.metadata - return [make_chunk(chunk.data, vector, metadata) for vector in vectors] + raise ValueError(f"Unsupported model: {chunk.embedding_model}") + vectors = embed_chunks([content], model) # type: ignore + chunk.vector = vectors[0] # type: ignore + return chunk -def embed_source_item( - item: SourceItem, - metadata: dict[str, Any] = {}, - chunk_size: int | None = None, -) -> list[Chunk]: - return [ - chunk - for data_chunk in item.data_chunks() - for chunk in embed_data_chunk( - data_chunk, item.as_payload() | metadata, chunk_size - ) - ] +def embed_source_item(item: SourceItem) -> list[Chunk]: + return [embed_chunk(chunk) for chunk in item.data_chunks()] diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 7b7f4d2..9631f37 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -4,8 +4,9 @@ import logging import pathlib import tempfile from contextlib import contextmanager -from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast +from typing import Any, Generator, Sequence, cast +from memory.common import chunker import pymupdf # PyMuPDF import pypandoc from PIL import Image @@ -15,19 +16,9 @@ logger = logging.getLogger(__name__) MulitmodalChunk = Image.Image | str -class Page(TypedDict): - contents: Sequence[MulitmodalChunk] - metadata: dict[str, Any] - # This is used to override the default chunk size for the page - chunk_size: NotRequired[int] - - @dataclass class DataChunk: data: Sequence[MulitmodalChunk] - collection: str | None = None - embedding_model: str | None = None - max_size: int | None = None metadata: dict[str, Any] = field(default_factory=dict) @@ -48,18 +39,18 @@ def page_to_image(page: pymupdf.Page) -> Image.Image: return Image.frombytes("RGB", [pix.width, pix.height], pix.samples) -def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]: +def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]: with as_file(content) as file_path: with pymupdf.open(file_path) as pdf: return [ - { - "contents": [page_to_image(page)], - "metadata": { + DataChunk( + data=[page_to_image(page)], + metadata={ "page": page.number, "width": page.rect.width, "height": page.rect.height, }, - } + ) for page in pdf.pages() ] @@ -93,7 +84,7 @@ def docx_to_pdf( raise -def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]: +def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[DataChunk]: """Extract content from DOCX by converting to PDF first, then processing""" with as_file(docx_path) as file_path: pdf_path = docx_to_pdf(file_path) @@ -101,57 +92,47 @@ def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]: return doc_to_images(pdf_path) -def extract_image(content: bytes | str | pathlib.Path) -> list[Page]: +def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]: if isinstance(content, pathlib.Path): image = Image.open(content) elif isinstance(content, bytes): image = Image.open(io.BytesIO(content)) else: raise ValueError(f"Unsupported content type: {type(content)}") - return [{"contents": [image], "metadata": {}}] + return [DataChunk(data=[image])] -def extract_text(content: bytes | str | pathlib.Path) -> list[Page]: +def extract_text( + content: bytes | str | pathlib.Path, chunk_size: int | None = None +) -> list[DataChunk]: if isinstance(content, pathlib.Path): content = content.read_text() if isinstance(content, bytes): content = content.decode("utf-8") - return [{"contents": [cast(str, content)], "metadata": {}}] + content = cast(str, content) + chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS) + return [DataChunk(data=[c]) for c in chunks] def extract_data_chunks( mime_type: str, content: bytes | str | pathlib.Path, - collection: str | None = None, - embedding_model: str | None = None, chunk_size: int | None = None, ) -> list[DataChunk]: - pages = [] + chunks = [] logger.info(f"Extracting content from {mime_type}") if mime_type == "application/pdf": - pages = doc_to_images(content) + chunks = doc_to_images(content) elif mime_type in [ "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/msword", ]: logger.info(f"Extracting content from {content}") - pages = extract_docx(content) - logger.info(f"Extracted {len(pages)} pages from {content}") + chunks = extract_docx(content) + logger.info(f"Extracted {len(chunks)} pages from {content}") elif mime_type.startswith("text/"): - pages = extract_text(content) + chunks = extract_text(content, chunk_size) elif mime_type.startswith("image/"): - pages = extract_image(content) - - if chunk_size: - pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages] - - return [ - DataChunk( - data=page["contents"], - collection=collection, - embedding_model=embedding_model, - max_size=chunk_size, - ) - for page in pages - ] + chunks = extract_image(content) + return chunks diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py index f921802..03b5f67 100644 --- a/tests/memory/common/test_embedding.py +++ b/tests/memory/common/test_embedding.py @@ -1,16 +1,10 @@ -import pathlib -import uuid -from unittest.mock import Mock, patch -from typing import cast +from unittest.mock import Mock import pytest -from PIL import Image -from memory.common import settings, collections +from memory.common import collections from memory.common.embedding import ( embed_mixed, embed_text, - make_chunk, - write_to_file, ) @@ -57,302 +51,3 @@ def test_embed_text(mock_embed): def test_embed_mixed(mock_embed): items = ["text", {"type": "image", "data": "base64"}] assert embed_mixed(items) == [[0]] - - -def test_write_to_file_text(mock_file_storage): - """Test writing a string to a file.""" - chunk_id = "test-chunk-id" - content = "This is a test string" - - file_path = write_to_file(chunk_id, content) - - assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt" - assert file_path.exists() - assert file_path.read_text() == content - - -def test_write_to_file_bytes(mock_file_storage): - """Test writing bytes to a file.""" - chunk_id = "test-chunk-id" - content = b"These are test bytes" - - file_path = write_to_file(chunk_id, content) # type: ignore - - assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin" - assert file_path.exists() - assert file_path.read_bytes() == content - - -def test_write_to_file_image(mock_file_storage): - """Test writing an image to a file.""" - img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - chunk_id = "test-chunk-id" - - file_path = write_to_file(chunk_id, img) # type: ignore - - assert file_path == settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png" - assert file_path.exists() - # Verify it's a valid image file by opening it - image = Image.open(file_path) - assert image.size == (100, 100) - - -def test_write_to_file_unsupported_type(mock_file_storage): - """Test that an error is raised for unsupported content types.""" - chunk_id = "test-chunk-id" - content = 123 # Integer is not a supported type - - with pytest.raises(ValueError, match="Unsupported content type"): - write_to_file(chunk_id, content) # type: ignore - - -def test_make_chunk_text_only(mock_file_storage, db_session): - """Test creating a chunk from string content.""" - contents = ["text content 1", "text content 2"] - vector = [0.1, 0.2, 0.3] - metadata = {"doc_type": "test", "source": "unit-test"} - - with patch.object( - uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000001") - ): - chunk = make_chunk(contents, vector, metadata) # type: ignore - - assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000001" - assert cast(str, chunk.content) == "text content 1\n\ntext content 2" - assert chunk.file_path is None - assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL - assert chunk.vector == vector - assert chunk.item_metadata == metadata - - -def test_make_chunk_single_image(mock_file_storage, db_session): - """Test creating a chunk from a single image.""" - img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - contents = [img] - vector = [0.1, 0.2, 0.3] - metadata = {"doc_type": "test", "source": "unit-test"} - - with patch.object( - uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000002") - ): - chunk = make_chunk(contents, vector, metadata) # type: ignore - - assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000002" - assert chunk.content is None - assert cast(str, chunk.file_path) == str( - settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png", - ) - assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL - assert chunk.vector == vector - assert chunk.item_metadata == metadata - - # Verify the file exists - assert pathlib.Path(cast(str, chunk.file_path)).exists() - - -def test_make_chunk_mixed_content(mock_file_storage, db_session): - """Test creating a chunk from mixed content (string and image).""" - img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - contents = ["text content", img] - vector = [0.1, 0.2, 0.3] - metadata = {"doc_type": "test", "source": "unit-test"} - - with patch.object( - uuid, "uuid4", return_value=uuid.UUID("00000000-0000-0000-0000-000000000003") - ): - chunk = make_chunk(contents, vector, metadata) # type: ignore - - assert cast(str, chunk.id) == "00000000-0000-0000-0000-000000000003" - assert chunk.content is None - assert cast(str, chunk.file_path) == str( - settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*", - ) - assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL - assert chunk.vector == vector - assert chunk.item_metadata == metadata - - # Verify the files exist - assert ( - settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_0.txt" - ).exists() - assert ( - settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png" - ).exists() - - -@pytest.mark.parametrize( - "data,embedding_model,collection,expected_model,expected_count,expected_has_content", - [ - # Text-only with default model - ( - ["text content 1", "text content 2"], - None, - None, - settings.TEXT_EMBEDDING_MODEL, - 2, - True, - ), - # Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content - ( - ["text content"], - settings.MIXED_EMBEDDING_MODEL, - None, - settings.TEXT_EMBEDDING_MODEL, - 1, - True, - ), - # Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content - (["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True), - (["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True), - (["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True), - # Unknown collection falls back to default - (["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True), - # Explicit model takes precedence over collection - ( - ["text content"], - settings.TEXT_EMBEDDING_MODEL, - "photo", - settings.TEXT_EMBEDDING_MODEL, - 1, - True, - ), - ], -) -def test_embed_data_chunk_scenarios( - data, - embedding_model, - collection, - expected_model, - expected_count, - expected_has_content, - mock_embed, - mock_file_storage, -): - """Test various embedding scenarios for data chunks.""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - chunk = DataChunk( - data=data, - embedding_model=embedding_model, - collection=collection, - metadata={"source": "test"}, - ) - - result = embed_data_chunk(chunk, {"doc_type": "test"}) - - assert len(result) == expected_count - assert all(cast(str, c.embedding_model) == expected_model for c in result) - if expected_has_content: - assert all(c.content is not None for c in result) - assert all(c.file_path is None for c in result) - else: - assert all(c.content is None for c in result) - assert all(c.file_path is not None for c in result) - assert all( - c.item_metadata == {"source": "test", "doc_type": "test"} for c in result - ) - - -def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage): - """Test embedding mixed content (text and images).""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - img = Image.new("RGB", (100, 100), color=(73, 109, 137)) - chunk = DataChunk( - data=["text content", img], - embedding_model=settings.MIXED_EMBEDDING_MODEL, - metadata={"source": "test"}, - ) - - result = embed_data_chunk(chunk) - - assert len(result) == 1 # Mixed content returns single vector - assert result[0].content is None # Mixed content stored in files - assert result[0].file_path is not None - assert cast(str, result[0].embedding_model) == settings.MIXED_EMBEDDING_MODEL - - -@pytest.mark.parametrize( - "chunk_max_size,chunk_size_param,expected_chunk_size", - [ - (512, 1024, 512), # chunk.max_size takes precedence - (None, 2048, 2048), # chunk_size parameter used when max_size is None - (256, None, 256), # chunk.max_size used when parameter is None - ], -) -def test_embed_data_chunk_chunk_size_handling( - chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage -): - """Test chunk size parameter handling.""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - chunk = DataChunk( - data=["text content"], max_size=chunk_max_size, metadata={"source": "test"} - ) - - with patch("memory.common.embedding.embed_text") as mock_embed_text: - mock_embed_text.return_value = [[0.1, 0.2, 0.3]] - - result = embed_data_chunk(chunk, chunk_size=chunk_size_param) - - mock_embed_text.assert_called_once() - args, kwargs = mock_embed_text.call_args - assert kwargs["chunk_size"] == expected_chunk_size - - -def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage): - """Test that chunk metadata and parameter metadata are properly merged.""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - chunk = DataChunk( - data=["text content"], metadata={"source": "test", "type": "chunk"} - ) - metadata = { - "doc_type": "test", - "source": "override", - } # chunk.metadata takes precedence over parameter metadata - - result = embed_data_chunk(chunk, metadata) - - assert len(result) == 1 - expected_metadata = { - "source": "test", - "type": "chunk", - "doc_type": "test", - } # chunk source wins - assert result[0].item_metadata == expected_metadata - - -def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage): - """Test error handling for unsupported embedding model.""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - chunk = DataChunk( - data=["text content"], - embedding_model="unsupported-model", - metadata={"source": "test"}, - ) - - with pytest.raises(ValueError, match="Unsupported model: unsupported-model"): - embed_data_chunk(chunk) - - -def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage): - """Test handling of empty data.""" - from memory.common.extract import DataChunk - from memory.common.embedding import embed_data_chunk - - chunk = DataChunk(data=[], metadata={"source": "test"}) - - # Should handle empty data gracefully - with patch("memory.common.embedding.embed_text") as mock_embed_text: - mock_embed_text.return_value = [] - - result = embed_data_chunk(chunk) - - assert result == [] diff --git a/tests/memory/common/test_extract.py b/tests/memory/common/test_extract.py index 6a40bdc..bf03725 100644 --- a/tests/memory/common/test_extract.py +++ b/tests/memory/common/test_extract.py @@ -11,6 +11,7 @@ from memory.common.extract import ( extract_image, docx_to_pdf, extract_docx, + DataChunk, ) @@ -47,8 +48,8 @@ def test_as_file_with_str(): @pytest.mark.parametrize( "input_content,expected", [ - ("simple text", [{"contents": ["simple text"], "metadata": {}}]), - (b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]), + ("simple text", [DataChunk(data=["simple text"], metadata={})]), + (b"bytes text", [DataChunk(data=["bytes text"], metadata={})]), ], ) def test_extract_text(input_content, expected): @@ -60,7 +61,7 @@ def test_extract_text_with_path(tmp_path): test_file.write_text("file text content") assert extract_text(test_file) == [ - {"contents": ["file text content"], "metadata": {}} + DataChunk(data=["file text content"], metadata={}) ] @@ -72,8 +73,8 @@ def test_doc_to_images(): for page, pdf_page in zip(result, pdf.pages()): pix = pdf_page.get_pixmap() img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) - assert page["contents"] == [img] - assert page["metadata"] == { + assert page.data == [img] + assert page.metadata == { "page": pdf_page.number, "width": pdf_page.rect.width, "height": pdf_page.rect.height, @@ -86,8 +87,8 @@ def test_extract_image_with_path(tmp_path): img.save(img_path) (page,) = extract_image(img_path) - assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore - assert page["metadata"] == {} + assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore + assert page.metadata == {} def test_extract_image_with_bytes(): @@ -97,8 +98,8 @@ def test_extract_image_with_bytes(): img_bytes = buffer.getvalue() (page,) = extract_image(img_bytes) - assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore - assert page["metadata"] == {} + assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore + assert page.metadata == {} def test_extract_image_with_str():