mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-08 13:24:41 +02:00
simplify embedding
This commit is contained in:
parent
9f1632555b
commit
a5618f3543
@ -8,6 +8,7 @@ import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar, Iterable, Sequence, cast
|
||||
import uuid
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import (
|
||||
@ -88,6 +89,19 @@ def clean_filename(filename: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", filename).strip("_")
|
||||
|
||||
|
||||
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
    """Return one file name per image, saving unnamed images to disk first.

    Images that already carry a non-empty ``filename`` are left untouched.
    Any other image is saved as ``<chunk_id>_<index>.<format>`` and the
    generated name is recorded on the image, so the returned list always
    lines up with ``images``.

    Args:
        chunk_id: Identifier used as the prefix of generated file names.
        images: Images to name; unnamed ones are written to disk.

    Returns:
        The file names, in the same order as ``images``.
    """
    filenames: list[str] = []
    for i, image in enumerate(images):
        # Images created in memory may not have a ``filename`` attribute at
        # all, so the lookup has to be guarded.
        filename = getattr(image, "filename", None)
        if not filename:
            # ``format`` is None for images that were not read from a file;
            # fall back to PNG so the save call gets a usable extension.
            extension = image.format or "png"
            filename = f"{chunk_id}_{i}.{extension}"
            image.save(filename)
            # Remember where the image went; without this the generated name
            # was lost and callers got an empty/missing name back.
            image.filename = filename  # type: ignore
        filenames.append(filename)

    return filenames
|
||||
|
||||
|
||||
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
    """Return the text chunk followed by the images whose file name occurs in it.

    Args:
        chunk: Text to search for image file names.
        images: Candidate images.

    Returns:
        A list starting with ``chunk``, followed by every image whose
        non-empty ``filename`` is a substring of ``chunk``.
    """
    referenced = [
        image
        for image in images
        # An empty name is a substring of every string, so unnamed images
        # would otherwise be attached to every chunk; skip them, and tolerate
        # images that have no ``filename`` attribute at all.
        if (name := getattr(image, "filename", None)) and name in chunk
    ]
    return [chunk, *referenced]  # type: ignore
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
"""Stores content chunks with their vector embeddings."""
|
||||
|
||||
@ -107,6 +121,7 @@ class Chunk(Base):
|
||||
checked_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
vector: ClassVar[list[float] | None] = None
|
||||
item_metadata: ClassVar[dict[str, Any] | None] = None
|
||||
images: list[Image.Image] = []
|
||||
|
||||
# One of file_path or content must be populated
|
||||
__table_args__ = (
|
||||
@ -119,11 +134,8 @@ class Chunk(Base):
|
||||
if self.file_path is None:
|
||||
return [cast(str, self.content)]
|
||||
|
||||
path = pathlib.Path(self.file_path.replace("/app/", ""))
|
||||
if cast(str, self.file_path).endswith("*"):
|
||||
files = list(path.parent.glob(path.name))
|
||||
else:
|
||||
files = [path]
|
||||
paths = [pathlib.Path(p) for p in self.file_path.split("\n")]
|
||||
files = [path for path in paths if path.exists()]
|
||||
|
||||
items = []
|
||||
for file_path in files:
|
||||
@ -179,13 +191,37 @@ class SourceItem(Base):
|
||||
"""Get vector IDs from associated chunks."""
|
||||
return [chunk.id for chunk in self.chunks]
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
    """Wrap this item's text content in a single DataChunk.

    Returns:
        A one-element list holding a DataChunk whose ``data`` is a list
        containing ``self.content`` and whose ``collection`` is
        ``self.modality``.
    """
    return [
        extract.DataChunk(
            # ``data`` is a sequence of items. A bare string would be seen as
            # a sequence of single characters by code that iterates or joins
            # it, so the text is wrapped in a list.
            data=[cast(str, self.content)],
            collection=cast(str, self.modality),
        )
    ]
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
    """Split this item into the pieces that should each become a chunk.

    Non-empty ``content`` is split with ``chunker.chunk_text``, one
    single-element list per piece. If ``mime_type`` starts with ``image/``,
    the image opened from ``self.filename`` is appended as one more entry.
    """
    pieces: list[list[extract.MulitmodalChunk]] = []

    text = cast(str | None, self.content)
    if text:
        pieces = [[part] for part in chunker.chunk_text(cast(str, self.content))]

    kind = cast(str | None, self.mime_type)
    if not kind or not kind.startswith("image/"):
        return pieces

    pieces.append([Image.open(self.filename)])
    return pieces
|
||||
|
||||
def _make_chunk(
    self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
):
    """Build a Chunk for this item from a mixed sequence of text and images.

    Strings in ``data`` are joined with blank lines into the chunk's text.
    Images are passed through ``image_filenames`` and their names stored,
    newline-separated, in ``file_path`` (``None`` when there are no images).
    ``metadata`` is merged over ``self.as_payload()``.
    """
    chunk_id = str(uuid.uuid4())

    # Partition the payload by type, keeping the original order within each.
    strings = []
    pictures = []
    for part in data:
        if isinstance(part, str):
            strings.append(part)
        if isinstance(part, Image.Image):
            pictures.append(part)

    text = "\n\n".join(strings)
    names = image_filenames(chunk_id, pictures)
    file_path = "\n".join(names) if names else None
    model = collections.collection_model(cast(str, self.modality))
    payload = self.as_payload() | metadata

    return Chunk(
        id=chunk_id,
        source=self,
        content=text,
        images=pictures,
        file_path=file_path,
        embedding_model=model,
        item_metadata=payload,
    )
|
||||
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
    """Return one Chunk per entry produced by ``_chunk_contents``.

    ``metadata`` is forwarded unchanged to ``_make_chunk`` for every entry.
    """
    built = []
    for contents in self._chunk_contents():
        built.append(self._make_chunk(contents, metadata))
    return built
|
||||
|
||||
def as_payload(self) -> dict:
|
||||
return {
|
||||
@ -319,18 +355,14 @@ class EmailAttachment(SourceItem):
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
if cast(str | None, self.filename):
|
||||
contents = pathlib.Path(cast(str, self.filename)).read_bytes()
|
||||
else:
|
||||
contents = cast(str, self.content)
|
||||
|
||||
return extract.extract_data_chunks(
|
||||
cast(str, self.mime_type),
|
||||
contents,
|
||||
collection=cast(str, self.modality),
|
||||
embedding_model=collections.collection_model(cast(str, self.modality)),
|
||||
)
|
||||
chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
|
||||
return [self._make_chunk(c.data, metadata | c.metadata) for c in chunks]
|
||||
|
||||
# Add indexes
|
||||
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
||||
@ -433,14 +465,10 @@ class Comic(SourceItem):
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
|
||||
image = Image.open(pathlib.Path(cast(str, self.filename)))
|
||||
return [
|
||||
extract.DataChunk(
|
||||
data=[image, cast(str, self.title), cast(str, self.author)],
|
||||
collection=cast(str, self.modality),
|
||||
)
|
||||
]
|
||||
description = f"{self.title} by {self.author}"
|
||||
return [[image, description]]
|
||||
|
||||
|
||||
class Book(Base):
|
||||
@ -538,16 +566,11 @@ class BookSection(SourceItem):
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
|
||||
texts += [(cast(str, self.content), self.start_page)]
|
||||
return [
|
||||
extract.DataChunk(
|
||||
data=[text],
|
||||
collection=cast(str, self.modality),
|
||||
metadata={"page": page_number},
|
||||
max_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
self._make_chunk([text.strip()], metadata | {"page": page_number})
|
||||
for text, page_number in texts
|
||||
]
|
||||
|
||||
@ -601,29 +624,18 @@ class BlogPost(SourceItem):
|
||||
}
|
||||
return {k: v for k, v in payload.items() if v}
|
||||
|
||||
def data_chunks(self) -> Iterable[extract.DataChunk]:
|
||||
def _chunk_contents(self) -> Sequence[Sequence[extract.MulitmodalChunk]]:
|
||||
images = [Image.open(image) for image in self.images]
|
||||
data = [cast(str, self.content)] + images
|
||||
|
||||
# Always embed the full content as a single chunk (if possible, of course)
|
||||
chunks = [
|
||||
extract.DataChunk(
|
||||
data=data,
|
||||
collection=cast(str, self.modality),
|
||||
max_size=chunker.EMBEDDING_MAX_TOKENS,
|
||||
)
|
||||
]
|
||||
content = cast(str, self.content)
|
||||
full_text = [content, *images]
|
||||
|
||||
# If the content is long enough, also embed it as chunks of the default size.
|
||||
tokens = chunker.approx_token_count(cast(str, self.content))
|
||||
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
chunks += [
|
||||
extract.DataChunk(
|
||||
data=data,
|
||||
collection=cast(str, self.modality),
|
||||
)
|
||||
]
|
||||
return chunks
|
||||
if tokens < chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||
return [full_text]
|
||||
|
||||
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
|
||||
return [full_text] + chunks
|
||||
|
||||
|
||||
class MiscDoc(SourceItem):
|
||||
|
@ -1,17 +1,16 @@
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import pathlib
|
||||
import uuid
|
||||
from typing import Any, Iterable, Literal, cast
|
||||
from typing import Iterable, Literal, cast
|
||||
|
||||
import voyageai
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import extract, settings
|
||||
from memory.common.chunker import chunk_text, DEFAULT_CHUNK_TOKENS, OVERLAP_TOKENS
|
||||
from memory.common.collections import ALL_COLLECTIONS, Vector
|
||||
from memory.common.chunker import (
|
||||
DEFAULT_CHUNK_TOKENS,
|
||||
OVERLAP_TOKENS,
|
||||
chunk_text,
|
||||
)
|
||||
from memory.common.collections import Vector
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common.extract import DataChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -72,91 +71,18 @@ def embed_mixed(
|
||||
return embed_chunks([chunks], model, input_type)
|
||||
|
||||
|
||||
def write_to_file(chunk_id: str, item: extract.MulitmodalChunk) -> pathlib.Path:
    """Store one chunk item under ``settings.CHUNK_STORAGE_DIR``.

    Strings go to ``<chunk_id>.txt``, bytes to ``<chunk_id>.bin`` and PIL
    images to ``<chunk_id>.png``.

    Returns:
        The path that was written.

    Raises:
        ValueError: If ``item`` is none of the supported types.
    """
    if isinstance(item, str):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.txt"
        target.write_text(item)
        return target

    if isinstance(item, bytes):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.bin"
        target.write_bytes(item)
        return target

    if isinstance(item, Image.Image):
        target = settings.CHUNK_STORAGE_DIR / f"{chunk_id}.png"
        item.save(target)
        return target

    raise ValueError(f"Unsupported content type: {type(item)}")
|
||||
|
||||
|
||||
def make_chunk(
    contents: Sequence[extract.MulitmodalChunk],
    vector: Vector,
    metadata: dict[str, Any] = {},
) -> Chunk:
    """Create a Chunk from a sequence of contents and its vector.

    Three cases are distinguished:
    - all strings: joined with blank lines into ``content``, text model;
    - exactly one non-text item: written to a single file whose absolute
      path becomes ``file_path``, mixed model;
    - anything else: every item is written to ``<chunk_id>_<index>.*`` and
      ``file_path`` is the glob pattern ``<chunk_id>_*``, mixed model.
    """
    chunk_id = str(uuid.uuid4())
    content = None
    filename = None

    only_text = all(isinstance(part, str) for part in contents)
    if only_text:
        model = settings.TEXT_EMBEDDING_MODEL
        content = "\n\n".join(cast(list[str], contents))
    elif len(contents) == 1:
        written = write_to_file(chunk_id, contents[0])
        filename = written.absolute().as_posix()
        model = settings.MIXED_EMBEDDING_MODEL
    else:
        for index, part in enumerate(contents):
            write_to_file(f"{chunk_id}_{index}", part)
        model = settings.MIXED_EMBEDDING_MODEL
        # The stored path is a glob pattern covering every file just written.
        pattern = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_*"
        filename = pattern.absolute().as_posix()

    return Chunk(
        id=chunk_id,
        file_path=filename,
        content=content,
        embedding_model=model,
        vector=vector,
        item_metadata=metadata,
    )
|
||||
|
||||
|
||||
def embed_data_chunk(
|
||||
chunk: DataChunk,
|
||||
metadata: dict[str, Any] = {},
|
||||
chunk_size: int | None = None,
|
||||
) -> list[Chunk]:
|
||||
chunk_size = chunk.max_size or chunk_size or DEFAULT_CHUNK_TOKENS
|
||||
|
||||
model = chunk.embedding_model
|
||||
if not model and chunk.collection:
|
||||
model = ALL_COLLECTIONS.get(chunk.collection, {}).get("model")
|
||||
if not model:
|
||||
model = settings.TEXT_EMBEDDING_MODEL
|
||||
|
||||
def embed_chunk(chunk: Chunk) -> Chunk:
|
||||
model = cast(str, chunk.embedding_model)
|
||||
if model == settings.TEXT_EMBEDDING_MODEL:
|
||||
vectors = embed_text(cast(list[str], chunk.data), chunk_size=chunk_size)
|
||||
content = cast(str, chunk.content)
|
||||
elif model == settings.MIXED_EMBEDDING_MODEL:
|
||||
vectors = embed_mixed(
|
||||
cast(list[extract.MulitmodalChunk], chunk.data),
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
content = [cast(str, chunk.content)] + chunk.images
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
metadata = metadata | chunk.metadata
|
||||
return [make_chunk(chunk.data, vector, metadata) for vector in vectors]
|
||||
raise ValueError(f"Unsupported model: {chunk.embedding_model}")
|
||||
vectors = embed_chunks([content], model) # type: ignore
|
||||
chunk.vector = vectors[0] # type: ignore
|
||||
return chunk
|
||||
|
||||
|
||||
def embed_source_item(
    item: SourceItem,
    metadata: dict[str, Any] = {},
    chunk_size: int | None = None,
) -> list[Chunk]:
    """Embed every data chunk of ``item`` and return the resulting Chunks.

    Each data chunk is passed to ``embed_data_chunk`` together with the
    item's payload merged with ``metadata`` and the given ``chunk_size``.
    """
    embedded = []
    for data_chunk in item.data_chunks():
        payload = item.as_payload() | metadata
        embedded.extend(embed_data_chunk(data_chunk, payload, chunk_size))
    return embedded
|
||||
def embed_source_item(item: SourceItem) -> list[Chunk]:
    """Run ``embed_chunk`` over every chunk from ``item.data_chunks()``."""
    embedded = []
    for chunk in item.data_chunks():
        embedded.append(embed_chunk(chunk))
    return embedded
|
||||
|
@ -4,8 +4,9 @@ import logging
|
||||
import pathlib
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, NotRequired, Sequence, TypedDict, cast
|
||||
from typing import Any, Generator, Sequence, cast
|
||||
|
||||
from memory.common import chunker
|
||||
import pymupdf # PyMuPDF
|
||||
import pypandoc
|
||||
from PIL import Image
|
||||
@ -15,19 +16,9 @@ logger = logging.getLogger(__name__)
|
||||
MulitmodalChunk = Image.Image | str
|
||||
|
||||
|
||||
class Page(TypedDict):
|
||||
contents: Sequence[MulitmodalChunk]
|
||||
metadata: dict[str, Any]
|
||||
# This is used to override the default chunk size for the page
|
||||
chunk_size: NotRequired[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataChunk:
|
||||
data: Sequence[MulitmodalChunk]
|
||||
collection: str | None = None
|
||||
embedding_model: str | None = None
|
||||
max_size: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@ -48,18 +39,18 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
|
||||
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
|
||||
def doc_to_images(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
|
||||
with as_file(content) as file_path:
|
||||
with pymupdf.open(file_path) as pdf:
|
||||
return [
|
||||
{
|
||||
"contents": [page_to_image(page)],
|
||||
"metadata": {
|
||||
DataChunk(
|
||||
data=[page_to_image(page)],
|
||||
metadata={
|
||||
"page": page.number,
|
||||
"width": page.rect.width,
|
||||
"height": page.rect.height,
|
||||
},
|
||||
}
|
||||
)
|
||||
for page in pdf.pages()
|
||||
]
|
||||
|
||||
@ -93,7 +84,7 @@ def docx_to_pdf(
|
||||
raise
|
||||
|
||||
|
||||
def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
|
||||
def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[DataChunk]:
|
||||
"""Extract content from DOCX by converting to PDF first, then processing"""
|
||||
with as_file(docx_path) as file_path:
|
||||
pdf_path = docx_to_pdf(file_path)
|
||||
@ -101,57 +92,47 @@ def extract_docx(docx_path: pathlib.Path | bytes | str) -> list[Page]:
|
||||
return doc_to_images(pdf_path)
|
||||
|
||||
|
||||
def extract_image(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
def extract_image(content: bytes | str | pathlib.Path) -> list[DataChunk]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
image = Image.open(content)
|
||||
elif isinstance(content, bytes):
|
||||
image = Image.open(io.BytesIO(content))
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
return [{"contents": [image], "metadata": {}}]
|
||||
return [DataChunk(data=[image])]
|
||||
|
||||
|
||||
def extract_text(content: bytes | str | pathlib.Path) -> list[Page]:
|
||||
def extract_text(
|
||||
content: bytes | str | pathlib.Path, chunk_size: int | None = None
|
||||
) -> list[DataChunk]:
|
||||
if isinstance(content, pathlib.Path):
|
||||
content = content.read_text()
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return [{"contents": [cast(str, content)], "metadata": {}}]
|
||||
content = cast(str, content)
|
||||
chunks = chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
|
||||
return [DataChunk(data=[c]) for c in chunks]
|
||||
|
||||
|
||||
def extract_data_chunks(
|
||||
mime_type: str,
|
||||
content: bytes | str | pathlib.Path,
|
||||
collection: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
) -> list[DataChunk]:
|
||||
pages = []
|
||||
chunks = []
|
||||
logger.info(f"Extracting content from {mime_type}")
|
||||
if mime_type == "application/pdf":
|
||||
pages = doc_to_images(content)
|
||||
chunks = doc_to_images(content)
|
||||
elif mime_type in [
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/msword",
|
||||
]:
|
||||
logger.info(f"Extracting content from {content}")
|
||||
pages = extract_docx(content)
|
||||
logger.info(f"Extracted {len(pages)} pages from {content}")
|
||||
chunks = extract_docx(content)
|
||||
logger.info(f"Extracted {len(chunks)} pages from {content}")
|
||||
elif mime_type.startswith("text/"):
|
||||
pages = extract_text(content)
|
||||
chunks = extract_text(content, chunk_size)
|
||||
elif mime_type.startswith("image/"):
|
||||
pages = extract_image(content)
|
||||
|
||||
if chunk_size:
|
||||
pages: list[Page] = [{**page, "chunk_size": chunk_size} for page in pages]
|
||||
|
||||
return [
|
||||
DataChunk(
|
||||
data=page["contents"],
|
||||
collection=collection,
|
||||
embedding_model=embedding_model,
|
||||
max_size=chunk_size,
|
||||
)
|
||||
for page in pages
|
||||
]
|
||||
chunks = extract_image(content)
|
||||
return chunks
|
||||
|
@ -1,16 +1,10 @@
|
||||
import pathlib
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import settings, collections
|
||||
from memory.common import collections
|
||||
from memory.common.embedding import (
|
||||
embed_mixed,
|
||||
embed_text,
|
||||
make_chunk,
|
||||
write_to_file,
|
||||
)
|
||||
|
||||
|
||||
@ -57,302 +51,3 @@ def test_embed_text(mock_embed):
|
||||
def test_embed_mixed(mock_embed):
    """embed_mixed returns the mocked vector for a text-plus-image payload."""
    image_item = {"type": "image", "data": "base64"}
    vectors = embed_mixed(["text", image_item])
    assert vectors == [[0]]
|
||||
|
||||
|
||||
def test_write_to_file_text(mock_file_storage):
    """A string is stored as ``<chunk_id>.txt`` with its text intact."""
    text = "This is a test string"

    written = write_to_file("test-chunk-id", text)

    expected_path = settings.CHUNK_STORAGE_DIR / "test-chunk-id.txt"
    assert written == expected_path
    assert written.exists()
    assert written.read_text() == text
|
||||
|
||||
|
||||
def test_write_to_file_bytes(mock_file_storage):
    """Bytes are stored as ``<chunk_id>.bin`` with their payload intact."""
    payload = b"These are test bytes"

    written = write_to_file("test-chunk-id", payload)  # type: ignore

    expected_path = settings.CHUNK_STORAGE_DIR / "test-chunk-id.bin"
    assert written == expected_path
    assert written.exists()
    assert written.read_bytes() == payload
|
||||
|
||||
|
||||
def test_write_to_file_image(mock_file_storage):
    """An image is stored as ``<chunk_id>.png`` and can be opened again."""
    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))

    written = write_to_file("test-chunk-id", picture)  # type: ignore

    assert written == settings.CHUNK_STORAGE_DIR / "test-chunk-id.png"
    assert written.exists()
    # Re-open the file to make sure a valid image was written.
    reopened = Image.open(written)
    assert reopened.size == (100, 100)
|
||||
|
||||
|
||||
def test_write_to_file_unsupported_type(mock_file_storage):
    """Unsupported content types are rejected with a ValueError."""
    not_supported = 123  # Integer is not a supported type

    with pytest.raises(ValueError, match="Unsupported content type"):
        write_to_file("test-chunk-id", not_supported)  # type: ignore
|
||||
|
||||
|
||||
def test_make_chunk_text_only(mock_file_storage, db_session):
    """Text-only contents end up in ``content``, with no file path."""
    fixed_id = "00000000-0000-0000-0000-000000000001"
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}
    texts = ["text content 1", "text content 2"]

    with patch.object(uuid, "uuid4", return_value=uuid.UUID(fixed_id)):
        chunk = make_chunk(texts, embedding, extra)  # type: ignore

    assert cast(str, chunk.id) == fixed_id
    assert cast(str, chunk.content) == "text content 1\n\ntext content 2"
    assert chunk.file_path is None
    assert cast(str, chunk.embedding_model) == settings.TEXT_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra
|
||||
|
||||
|
||||
def test_make_chunk_single_image(mock_file_storage, db_session):
    """A single image is written to one file that ``file_path`` points at."""
    fixed_id = "00000000-0000-0000-0000-000000000002"
    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}

    with patch.object(uuid, "uuid4", return_value=uuid.UUID(fixed_id)):
        chunk = make_chunk([picture], embedding, extra)  # type: ignore

    expected_path = str(
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000002.png"
    )
    assert cast(str, chunk.id) == fixed_id
    assert chunk.content is None
    assert cast(str, chunk.file_path) == expected_path
    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra

    # The image must actually have been written.
    assert pathlib.Path(cast(str, chunk.file_path)).exists()
|
||||
|
||||
|
||||
def test_make_chunk_mixed_content(mock_file_storage, db_session):
    """Text plus image is written to per-item files behind a glob path."""
    fixed_id = "00000000-0000-0000-0000-000000000003"
    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    embedding = [0.1, 0.2, 0.3]
    extra = {"doc_type": "test", "source": "unit-test"}

    with patch.object(uuid, "uuid4", return_value=uuid.UUID(fixed_id)):
        chunk = make_chunk(["text content", picture], embedding, extra)  # type: ignore

    expected_pattern = str(
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_*"
    )
    assert cast(str, chunk.id) == fixed_id
    assert chunk.content is None
    assert cast(str, chunk.file_path) == expected_pattern
    assert cast(str, chunk.embedding_model) == settings.MIXED_EMBEDDING_MODEL
    assert chunk.vector == embedding
    assert chunk.item_metadata == extra

    # One file per item, suffixed with the item's index.
    text_file = (
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_0.txt"
    )
    image_file = (
        settings.CHUNK_STORAGE_DIR / "00000000-0000-0000-0000-000000000003_1.png"
    )
    assert text_file.exists()
    assert image_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "data,embedding_model,collection,expected_model,expected_count,expected_has_content",
    [
        # Text-only with default model
        (
            ["text content 1", "text content 2"],
            None,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            2,
            True,
        ),
        # Text with explicit mixed model - but make_chunk still uses TEXT_EMBEDDING_MODEL for text-only content
        (
            ["text content"],
            settings.MIXED_EMBEDDING_MODEL,
            None,
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
        # Text collection model selection - make_chunk uses TEXT_EMBEDDING_MODEL for text-only content
        (["text content"], None, "mail", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "photo", settings.TEXT_EMBEDDING_MODEL, 1, True),
        (["text content"], None, "doc", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Unknown collection falls back to default
        (["text content"], None, "unknown", settings.TEXT_EMBEDDING_MODEL, 1, True),
        # Explicit model takes precedence over collection
        (
            ["text content"],
            settings.TEXT_EMBEDDING_MODEL,
            "photo",
            settings.TEXT_EMBEDDING_MODEL,
            1,
            True,
        ),
    ],
)
def test_embed_data_chunk_scenarios(
    data,
    embedding_model,
    collection,
    expected_model,
    expected_count,
    expected_has_content,
    mock_embed,
    mock_file_storage,
):
    """Check model, storage location and metadata across embedding scenarios."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    source_chunk = DataChunk(
        data=data,
        embedding_model=embedding_model,
        collection=collection,
        metadata={"source": "test"},
    )

    produced = embed_data_chunk(source_chunk, {"doc_type": "test"})

    assert len(produced) == expected_count
    merged = {"source": "test", "doc_type": "test"}
    for made in produced:
        assert cast(str, made.embedding_model) == expected_model
        if expected_has_content:
            # Inline content means nothing was written to a file.
            assert made.content is not None
            assert made.file_path is None
        else:
            assert made.content is None
            assert made.file_path is not None
        assert made.item_metadata == merged
|
||||
|
||||
|
||||
def test_embed_data_chunk_mixed_content(mock_embed, mock_file_storage):
    """Text plus image yields a single file-backed chunk on the mixed model."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    picture = Image.new("RGB", (100, 100), color=(73, 109, 137))
    source_chunk = DataChunk(
        data=["text content", picture],
        embedding_model=settings.MIXED_EMBEDDING_MODEL,
        metadata={"source": "test"},
    )

    produced = embed_data_chunk(source_chunk)

    assert len(produced) == 1  # Mixed content returns single vector
    (only,) = produced
    assert only.content is None  # Mixed content stored in files
    assert only.file_path is not None
    assert cast(str, only.embedding_model) == settings.MIXED_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "chunk_max_size,chunk_size_param,expected_chunk_size",
    [
        (512, 1024, 512),  # chunk.max_size takes precedence
        (None, 2048, 2048),  # chunk_size parameter used when max_size is None
        (256, None, 256),  # chunk.max_size used when parameter is None
    ],
)
def test_embed_data_chunk_chunk_size_handling(
    chunk_max_size, chunk_size_param, expected_chunk_size, mock_embed, mock_file_storage
):
    """The chunk size passed to embed_text follows the precedence rules."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    source_chunk = DataChunk(
        data=["text content"], max_size=chunk_max_size, metadata={"source": "test"}
    )

    with patch("memory.common.embedding.embed_text") as fake_embed_text:
        fake_embed_text.return_value = [[0.1, 0.2, 0.3]]

        embed_data_chunk(source_chunk, chunk_size=chunk_size_param)

        fake_embed_text.assert_called_once()
        passed = fake_embed_text.call_args.kwargs
        assert passed["chunk_size"] == expected_chunk_size
|
||||
|
||||
|
||||
def test_embed_data_chunk_metadata_merging(mock_embed, mock_file_storage):
    """Chunk metadata overrides the metadata passed as a parameter."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    source_chunk = DataChunk(
        data=["text content"], metadata={"source": "test", "type": "chunk"}
    )
    # "source" is present on both sides; the chunk's own value must win.
    passed_in = {"doc_type": "test", "source": "override"}

    produced = embed_data_chunk(source_chunk, passed_in)

    assert len(produced) == 1
    assert produced[0].item_metadata == {
        "source": "test",
        "type": "chunk",
        "doc_type": "test",
    }
|
||||
|
||||
|
||||
def test_embed_data_chunk_unsupported_model(mock_embed, mock_file_storage):
    """An unknown embedding model is rejected with a ValueError."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    source_chunk = DataChunk(
        data=["text content"],
        embedding_model="unsupported-model",
        metadata={"source": "test"},
    )

    expected_message = "Unsupported model: unsupported-model"
    with pytest.raises(ValueError, match=expected_message):
        embed_data_chunk(source_chunk)
|
||||
|
||||
|
||||
def test_embed_data_chunk_empty_data(mock_embed, mock_file_storage):
    """Empty data produces no chunks when embed_text returns no vectors."""
    from memory.common.extract import DataChunk
    from memory.common.embedding import embed_data_chunk

    empty_chunk = DataChunk(data=[], metadata={"source": "test"})

    # embed_text is replaced so that no vectors come back for the empty input.
    with patch("memory.common.embedding.embed_text") as fake_embed_text:
        fake_embed_text.return_value = []

        produced = embed_data_chunk(empty_chunk)

    assert produced == []
|
||||
|
@ -11,6 +11,7 @@ from memory.common.extract import (
|
||||
extract_image,
|
||||
docx_to_pdf,
|
||||
extract_docx,
|
||||
DataChunk,
|
||||
)
|
||||
|
||||
|
||||
@ -47,8 +48,8 @@ def test_as_file_with_str():
|
||||
@pytest.mark.parametrize(
|
||||
"input_content,expected",
|
||||
[
|
||||
("simple text", [{"contents": ["simple text"], "metadata": {}}]),
|
||||
(b"bytes text", [{"contents": ["bytes text"], "metadata": {}}]),
|
||||
("simple text", [DataChunk(data=["simple text"], metadata={})]),
|
||||
(b"bytes text", [DataChunk(data=["bytes text"], metadata={})]),
|
||||
],
|
||||
)
|
||||
def test_extract_text(input_content, expected):
|
||||
@ -60,7 +61,7 @@ def test_extract_text_with_path(tmp_path):
|
||||
test_file.write_text("file text content")
|
||||
|
||||
assert extract_text(test_file) == [
|
||||
{"contents": ["file text content"], "metadata": {}}
|
||||
DataChunk(data=["file text content"], metadata={})
|
||||
]
|
||||
|
||||
|
||||
@ -72,8 +73,8 @@ def test_doc_to_images():
|
||||
for page, pdf_page in zip(result, pdf.pages()):
|
||||
pix = pdf_page.get_pixmap()
|
||||
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
assert page["contents"] == [img]
|
||||
assert page["metadata"] == {
|
||||
assert page.data == [img]
|
||||
assert page.metadata == {
|
||||
"page": pdf_page.number,
|
||||
"width": pdf_page.rect.width,
|
||||
"height": pdf_page.rect.height,
|
||||
@ -86,8 +87,8 @@ def test_extract_image_with_path(tmp_path):
|
||||
img.save(img_path)
|
||||
|
||||
(page,) = extract_image(img_path)
|
||||
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
|
||||
assert page["metadata"] == {}
|
||||
assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore
|
||||
assert page.metadata == {}
|
||||
|
||||
|
||||
def test_extract_image_with_bytes():
|
||||
@ -97,8 +98,8 @@ def test_extract_image_with_bytes():
|
||||
img_bytes = buffer.getvalue()
|
||||
|
||||
(page,) = extract_image(img_bytes)
|
||||
assert page["contents"][0].tobytes() == img.convert("RGB").tobytes() # type: ignore
|
||||
assert page["metadata"] == {}
|
||||
assert page.data[0].tobytes() == img.convert("RGB").tobytes() # type: ignore
|
||||
assert page.metadata == {}
|
||||
|
||||
|
||||
def test_extract_image_with_str():
|
||||
|
Loading…
x
Reference in New Issue
Block a user