tests for models

Daniel O'Connell 2025-05-26 12:49:11 +02:00
parent a5618f3543
commit 482aefabe3
8 changed files with 1009 additions and 232 deletions

View File

@ -1,39 +0,0 @@
"""Update field for chunks
Revision ID: d292d48ec74e
Revises: 4684845ca51e
Create Date: 2025-05-04 12:56:53.231393
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "d292d48ec74e"
down_revision: Union[str, None] = "4684845ca51e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"chunk",
sa.Column(
"checked_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
)
op.drop_column("misc_doc", "mime_type")
def downgrade() -> None:
op.add_column(
"misc_doc",
sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
)
op.drop_column("chunk", "checked_at")

View File

@ -1,43 +0,0 @@
"""Add comics
Revision ID: b78b1fff9974
Revises: d292d48ec74e
Create Date: 2025-05-04 23:45:52.733301
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b78b1fff9974'
down_revision: Union[str, None] = 'd292d48ec74e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('comic',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('author', sa.Text(), nullable=True),
sa.Column('published', sa.DateTime(timezone=True), nullable=True),
sa.Column('volume', sa.Text(), nullable=True),
sa.Column('issue', sa.Text(), nullable=True),
sa.Column('page', sa.Integer(), nullable=True),
sa.Column('url', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['id'], ['source_item.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('comic_author_idx', 'comic', ['author'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('comic_author_idx', table_name='comic')
op.drop_table('comic')
# ### end Alembic commands ###

View File

@ -1,107 +0,0 @@
"""Add ebooks
Revision ID: fe570eab952a
Revises: b78b1fff9974
Create Date: 2025-05-23 16:37:53.354723
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "fe570eab952a"
down_revision: Union[str, None] = "b78b1fff9974"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"book",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("isbn", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=False),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("publisher", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.Column("language", sa.Text(), nullable=True),
sa.Column("edition", sa.Text(), nullable=True),
sa.Column("series", sa.Text(), nullable=True),
sa.Column("series_number", sa.Integer(), nullable=True),
sa.Column("total_pages", sa.Integer(), nullable=True),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("tags", sa.ARRAY(sa.Text()), nullable=False, server_default="{}"),
sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("isbn"),
)
op.create_index("book_author_idx", "book", ["author"], unique=False)
op.create_index("book_isbn_idx", "book", ["isbn"], unique=False)
op.create_index("book_title_idx", "book", ["title"], unique=False)
op.create_table(
"book_section",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("book_id", sa.BigInteger(), nullable=False),
sa.Column("section_title", sa.Text(), nullable=True),
sa.Column("section_number", sa.Integer(), nullable=True),
sa.Column("section_level", sa.Integer(), nullable=True),
sa.Column("start_page", sa.Integer(), nullable=True),
sa.Column("end_page", sa.Integer(), nullable=True),
sa.Column("parent_section_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(["book_id"], ["book.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["parent_section_id"],
["book_section.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("book_section_book_idx", "book_section", ["book_id"], unique=False)
op.create_index(
"book_section_level_idx",
"book_section",
["section_level", "section_number"],
unique=False,
)
op.create_index(
"book_section_parent_idx", "book_section", ["parent_section_id"], unique=False
)
op.drop_table("book_doc")
def downgrade() -> None:
op.create_table(
"book_doc",
sa.Column("id", sa.BIGINT(), autoincrement=False, nullable=False),
sa.Column("title", sa.TEXT(), autoincrement=False, nullable=True),
sa.Column("author", sa.TEXT(), autoincrement=False, nullable=True),
sa.Column("chapter", sa.TEXT(), autoincrement=False, nullable=True),
sa.Column(
"published",
postgresql.TIMESTAMP(timezone=True),
autoincrement=False,
nullable=True,
),
sa.ForeignKeyConstraint(
["id"], ["source_item.id"], name="book_doc_id_fkey", ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id", name="book_doc_pkey"),
)
op.drop_index("book_section_parent_idx", table_name="book_section")
op.drop_index("book_section_level_idx", table_name="book_section")
op.drop_index("book_section_book_idx", table_name="book_section")
op.drop_table("book_section")
op.drop_index("book_title_idx", table_name="book")
op.drop_index("book_isbn_idx", table_name="book")
op.drop_index("book_author_idx", table_name="book")
op.drop_table("book")

View File

@ -1,8 +1,8 @@
"""Initial structure for the database.
"""Initial structure
Revision ID: 4684845ca51e
Revises: a466a07360d5
Create Date: 2025-05-03 14:00:56.113840
Revision ID: d897c6353a84
Revises:
Create Date: 2025-05-26 10:55:17.311208
"""
@ -13,7 +13,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "4684845ca51e"
revision: str = "d897c6353a84"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -22,6 +22,35 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
op.create_table(
"book",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("isbn", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=False),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("publisher", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.Column("language", sa.Text(), nullable=True),
sa.Column("edition", sa.Text(), nullable=True),
sa.Column("series", sa.Text(), nullable=True),
sa.Column("series_number", sa.Integer(), nullable=True),
sa.Column("total_pages", sa.Integer(), nullable=True),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("tags", sa.ARRAY(sa.Text()), server_default="{}", nullable=False),
sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("isbn"),
)
op.create_index("book_author_idx", "book", ["author"], unique=False)
op.create_index("book_isbn_idx", "book", ["isbn"], unique=False)
op.create_index("book_title_idx", "book", ["title"], unique=False)
op.create_table(
"email_accounts",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -82,6 +111,12 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
@ -128,21 +163,53 @@ def upgrade() -> None:
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("url", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("domain", sa.Text(), nullable=True),
sa.Column("word_count", sa.Integer(), nullable=True),
sa.Column("images", sa.ARRAY(sa.Text()), nullable=True),
sa.Column(
"webpage_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("url"),
)
op.create_index("blog_post_author_idx", "blog_post", ["author"], unique=False)
op.create_index("blog_post_domain_idx", "blog_post", ["domain"], unique=False)
op.create_index("blog_post_published_idx", "blog_post", ["published"], unique=False)
op.create_index(
"blog_post_word_count_idx", "blog_post", ["word_count"], unique=False
)
op.create_table(
"book_doc",
"book_section",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("chapter", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.Column("book_id", sa.BigInteger(), nullable=False),
sa.Column("section_title", sa.Text(), nullable=True),
sa.Column("section_number", sa.Integer(), nullable=True),
sa.Column("section_level", sa.Integer(), nullable=True),
sa.Column("start_page", sa.Integer(), nullable=True),
sa.Column("end_page", sa.Integer(), nullable=True),
sa.Column("parent_section_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(["book_id"], ["book.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["parent_section_id"],
["book_section.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("book_section_book_idx", "book_section", ["book_id"], unique=False)
op.create_index(
"book_section_level_idx",
"book_section",
["section_level", "section_number"],
unique=False,
)
op.create_index(
"book_section_parent_idx", "book_section", ["parent_section_id"], unique=False
)
op.create_table(
"chat_message",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -165,7 +232,7 @@ def upgrade() -> None:
nullable=False,
),
sa.Column("source_id", sa.BigInteger(), nullable=False),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("file_paths", sa.ARRAY(sa.Text()), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("embedding_model", sa.Text(), nullable=True),
sa.Column(
@ -174,11 +241,31 @@ def upgrade() -> None:
server_default=sa.text("now()"),
nullable=True,
),
sa.CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
sa.Column(
"checked_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
sa.ForeignKeyConstraint(["source_id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("chunk_source_idx", "chunk", ["source_id"], unique=False)
op.create_table(
"comic",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("author", sa.Text(), nullable=True),
sa.Column("published", sa.DateTime(timezone=True), nullable=True),
sa.Column("volume", sa.Text(), nullable=True),
sa.Column("issue", sa.Text(), nullable=True),
sa.Column("page", sa.Integer(), nullable=True),
sa.Column("url", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("comic_author_idx", "comic", ["author"], unique=False)
op.create_table(
"git_commit",
sa.Column("id", sa.BigInteger(), nullable=False),
@ -263,7 +350,6 @@ def upgrade() -> None:
"misc_doc",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("path", sa.Text(), nullable=True),
sa.Column("mime_type", sa.TEXT(), autoincrement=False, nullable=True),
sa.ForeignKeyConstraint(["id"], ["source_item.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
@ -321,15 +407,25 @@ def downgrade() -> None:
op.drop_index("git_files_idx", table_name="git_commit", postgresql_using="gin")
op.drop_index("git_date_idx", table_name="git_commit")
op.drop_table("git_commit")
op.drop_index("comic_author_idx", table_name="comic")
op.drop_table("comic")
op.drop_index("chunk_source_idx", table_name="chunk")
op.drop_table("chunk")
op.drop_index("chat_channel_idx", table_name="chat_message")
op.drop_table("chat_message")
op.drop_table("book_doc")
op.drop_index("book_section_parent_idx", table_name="book_section")
op.drop_index("book_section_level_idx", table_name="book_section")
op.drop_index("book_section_book_idx", table_name="book_section")
op.drop_table("book_section")
op.drop_index("blog_post_word_count_idx", table_name="blog_post")
op.drop_index("blog_post_published_idx", table_name="blog_post")
op.drop_index("blog_post_domain_idx", table_name="blog_post")
op.drop_index("blog_post_author_idx", table_name="blog_post")
op.drop_table("blog_post")
op.drop_index("source_tags_idx", table_name="source_item", postgresql_using="gin")
op.drop_index("source_status_idx", table_name="source_item")
op.drop_index("source_modality_idx", table_name="source_item")
op.drop_index("source_filename_idx", table_name="source_item")
op.drop_table("source_item")
op.drop_index("rss_feeds_tags_idx", table_name="rss_feeds", postgresql_using="gin")
op.drop_index("rss_feeds_active_idx", table_name="rss_feeds")
@ -340,3 +436,7 @@ def downgrade() -> None:
op.drop_index("email_accounts_address_idx", table_name="email_accounts")
op.drop_index("email_accounts_active_idx", table_name="email_accounts")
op.drop_table("email_accounts")
op.drop_index("book_title_idx", table_name="book")
op.drop_index("book_isbn_idx", table_name="book")
op.drop_index("book_author_idx", table_name="book")
op.drop_table("book")

View File

@ -92,14 +92,19 @@ def clean_filename(filename: str) -> str:
def image_filenames(chunk_id: str, images: list[Image.Image]) -> list[str]:
for i, image in enumerate(images):
if not image.filename: # type: ignore
filename = f"{chunk_id}_{i}.{image.format}" # type: ignore
filename = settings.CHUNK_STORAGE_DIR / f"{chunk_id}_{i}.{image.format}" # type: ignore
image.save(filename)
image.filename = str(filename) # type: ignore
return [image.filename for image in images] # type: ignore
def add_pics(chunk: str, images: list[Image.Image]) -> list[extract.MulitmodalChunk]:
return [chunk] + [i for i in images if i.filename in chunk] # type: ignore
return [chunk] + [
i
for i in images
if getattr(i, "filename", None) and i.filename in chunk # type: ignore
]
class Chunk(Base):
@ -114,7 +119,9 @@ class Chunk(Base):
source_id = Column(
BigInteger, ForeignKey("source_item.id", ondelete="CASCADE"), nullable=False
)
file_path = Column(Text) # Path to content if stored as a file
file_paths = Column(
ARRAY(Text), nullable=True
) # Paths to content if stored as files
content = Column(Text) # Direct content storage
embedding_model = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now())
@ -125,16 +132,16 @@ class Chunk(Base):
# One of file_paths or content must be populated
__table_args__ = (
CheckConstraint("(file_path IS NOT NULL) OR (content IS NOT NULL)"),
CheckConstraint("(file_paths IS NOT NULL) OR (content IS NOT NULL)"),
Index("chunk_source_idx", "source_id"),
)
@property
def data(self) -> list[bytes | str | Image.Image]:
if self.file_path is None:
if self.file_paths is None:
return [cast(str, self.content)]
paths = [pathlib.Path(p) for p in self.file_path.split("\n")]
paths = [pathlib.Path(cast(str, p)) for p in self.file_paths]
files = [path for path in paths if path.exists()]
items = []
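
The file_path -> file_paths change in Chunk also swaps the storage encoding: a single Text column holding newline-joined paths becomes a Text array. This commit regenerates the initial migration rather than altering the column in place, so no data migration ships with it; if one were ever needed, it might look roughly like this sketch (table and column names come from the models above, everything else is assumed):

# Hypothetical one-off migration: split newline-joined file_path values into
# the new file_paths array column. Not part of this commit.
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

def upgrade() -> None:
    op.add_column(
        "chunk",
        sa.Column("file_paths", postgresql.ARRAY(sa.Text()), nullable=True),
    )
    # string_to_array mirrors the old "\n".join(...) encoding
    op.execute(
        "UPDATE chunk SET file_paths = string_to_array(file_path, E'\\n') "
        "WHERE file_path IS NOT NULL"
    )
    op.drop_column("chunk", "file_path")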
@ -205,16 +212,16 @@ class SourceItem(Base):
self, data: Sequence[extract.MulitmodalChunk], metadata: dict[str, Any] = {}
):
chunk_id = str(uuid.uuid4())
text = "\n\n".join(c for c in data if isinstance(c, str))
text = "\n\n".join(c for c in data if isinstance(c, str) and c.strip())
images = [c for c in data if isinstance(c, Image.Image)]
image_names = image_filenames(chunk_id, images)
chunk = Chunk(
id=chunk_id,
source=self,
content=text,
content=text or None,
images=images,
file_path="\n".join(image_names) if image_names else None,
file_paths=image_names,
embedding_model=collections.collection_model(cast(str, self.modality)),
item_metadata=self.as_payload() | metadata,
)
@ -555,7 +562,7 @@ class BookSection(SourceItem):
)
def as_payload(self) -> dict:
return {
vals = {
"source_id": self.id,
"book_id": self.book_id,
"section_title": self.section_title,
@ -565,14 +572,18 @@ class BookSection(SourceItem):
"end_page": self.end_page,
"tags": self.tags,
}
return {k: v for k, v in vals.items() if v}
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
if not cast(str, self.content.strip()):
return []
texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
texts += [(cast(str, self.content), self.start_page)]
return [
self._make_chunk([text.strip()], metadata | {"page": page_number})
for text, page_number in texts
]
if text and text.strip()
] + [self._make_chunk([cast(str, self.content.strip())], metadata)]
class BlogPost(SourceItem):
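
Since removed and added lines are interleaved in the hunk above, the resulting BookSection.data_chunks is easier to read assembled from its added and unchanged lines (signature taken from the hunk context):

def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
    if not cast(str, self.content.strip()):
        return []
    texts = [(page, i + self.start_page) for i, page in enumerate(self.pages)]
    return [
        self._make_chunk([text.strip()], metadata | {"page": page_number})
        for text, page_number in texts
        if text and text.strip()
    ] + [self._make_chunk([cast(str, self.content.strip())], metadata)]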
@ -628,14 +639,15 @@ class BlogPost(SourceItem):
images = [Image.open(image) for image in self.images]
content = cast(str, self.content)
full_text = [content, *images]
full_text = [content.strip(), *images]
tokens = chunker.approx_token_count(cast(str, self.content))
if tokens < chunker.DEFAULT_CHUNK_TOKENS * 2:
return [full_text]
chunks = []
tokens = chunker.approx_token_count(content)
if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]
return [full_text] + chunks
all_chunks = [full_text] + chunks
return [c for c in all_chunks if c and all(i for i in c)]
class MiscDoc(SourceItem):
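
The same applies to the BlogPost._chunk_contents hunk above; assembled from its added and unchanged lines, the new body reads as follows (method signature assumed):

def _chunk_contents(self):
    images = [Image.open(image) for image in self.images]
    content = cast(str, self.content)
    full_text = [content.strip(), *images]

    chunks = []
    tokens = chunker.approx_token_count(content)
    if tokens > chunker.DEFAULT_CHUNK_TOKENS * 2:
        chunks = [add_pics(c, images) for c in chunker.chunk_text(content)]

    all_chunks = [full_text] + chunks
    return [c for c in all_chunks if c and all(i for i in c)]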

View File

@ -1,5 +1,380 @@
from memory.common.db.models import SourceItem
from sqlalchemy.orm import Session
from unittest.mock import patch, Mock
from typing import cast
import pytest
from PIL import Image
from datetime import datetime
from memory.common import settings
from memory.common import chunker
from memory.common.db.models import (
Chunk,
clean_filename,
image_filenames,
add_pics,
MailMessage,
EmailAttachment,
BookSection,
BlogPost,
)
@pytest.fixture
def default_chunk_size():
chunk_length = chunker.DEFAULT_CHUNK_TOKENS
real_chunker = chunker.chunk_text
def chunk_text(text: str, max_tokens: int = 0):
max_tokens = max_tokens or chunk_length
return real_chunker(text, max_tokens=max_tokens)
def set_size(new_size: int):
nonlocal chunk_length
chunk_length = new_size
with patch.object(chunker, "chunk_text", chunk_text):
yield set_size
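
A possible usage of this fixture (hypothetical test, not part of the commit), assuming chunker.chunk_text returns an iterable of text chunks:

def test_small_chunks(default_chunk_size):
    default_chunk_size(10)  # later chunker.chunk_text calls now use max_tokens=10
    pieces = list(chunker.chunk_text("some reasonably long text " * 50))
    assert len(pieces) > 1  # long text splits into several small chunks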
@pytest.mark.parametrize(
"input_filename,expected",
[
("normal_file.txt", "normal_file_txt"),
("file with spaces.pdf", "file_with_spaces_pdf"),
("file-with-dashes.doc", "file_with_dashes_doc"),
("file@#$%^&*()+={}[]|\\:;\"'<>,.?/~`", "file"),
("___multiple___underscores___", "multiple___underscores"),
("", ""),
("123", "123"),
("file.with.multiple.dots.txt", "file_with_multiple_dots_txt"),
],
)
def test_clean_filename(input_filename, expected):
assert clean_filename(input_filename) == expected
def test_image_filenames_with_existing_filenames(tmp_path):
"""Test image_filenames when images already have filenames"""
chunk_id = "test_chunk_123"
# Create actual test images and load them from files (which sets filename)
image1_path = tmp_path / "existing1.png"
image2_path = tmp_path / "existing2.jpg"
# Create and save images first
img1 = Image.new("RGB", (1, 1), color="red")
img1.save(image1_path)
img2 = Image.new("RGB", (1, 1), color="blue")
img2.save(image2_path)
# Load images from files (this sets the filename attribute)
image1 = Image.open(image1_path)
image2 = Image.open(image2_path)
images = [image1, image2]
result = image_filenames(chunk_id, images)
assert result == [str(image1_path), str(image2_path)]
def test_image_filenames_without_existing_filenames():
"""Test image_filenames when images don't have filenames"""
chunk_id = "test_chunk_456"
# Create actual test images without filenames
image1 = Image.new("RGB", (1, 1), color="red")
image1.format = "PNG"
# Manually set filename to None to simulate no filename
object.__setattr__(image1, "filename", None)
image2 = Image.new("RGB", (1, 1), color="blue")
image2.format = "JPEG"
object.__setattr__(image2, "filename", None)
images = [image1, image2]
result = image_filenames(chunk_id, images)
expected_filenames = [
str(settings.CHUNK_STORAGE_DIR / f"{chunk_id}_0.PNG"),
str(settings.CHUNK_STORAGE_DIR / f"{chunk_id}_1.JPEG"),
]
assert result == expected_filenames
assert (settings.CHUNK_STORAGE_DIR / f"{chunk_id}_0.PNG").exists()
assert (settings.CHUNK_STORAGE_DIR / f"{chunk_id}_1.JPEG").exists()
def test_add_pics():
"""Test add_pics function with mock-like behavior"""
chunk = "This is a test chunk with image1.png content"
image1 = Image.new("RGB", (1, 1), color="red")
object.__setattr__(image1, "filename", "image1.png")
image2 = Image.new("RGB", (1, 1), color="blue")
object.__setattr__(image2, "filename", "image2.jpg")
ignored_image = Image.new("RGB", (1, 1), color="blue")
images = [image1, image2, ignored_image]
result = add_pics(chunk, images)
# Should include the chunk and only images whose filename is in the chunk
assert result == [chunk, image1]
def test_chunk_data_property_content_only():
"""Test Chunk.data property when only content is set"""
source = SourceItem(sha256=b"test123", content="test", modality="text")
chunk = Chunk(source=source, content="Test content", embedding_model="test-model")
result = chunk.data
assert result == ["Test content"]
def test_chunk_data_property_with_files(tmp_path):
"""Test Chunk.data property when file_paths are set"""
# Create test files
text_file = tmp_path / "test.txt"
text_file.write_text("Text file content")
bin_file = tmp_path / "test.bin"
bin_file.write_bytes(b"Binary content")
image_file = tmp_path / "test.png"
# Create a simple 1x1 pixel PNG
img = Image.new("RGB", (1, 1), color="red")
img.save(image_file)
source = SourceItem(sha256=b"test123", content="test", modality="text")
chunk = Chunk(
source=source,
file_paths=[
str(text_file),
str(bin_file),
str(image_file),
"/missing/file.png",
],
embedding_model="test-model",
)
result = chunk.data
assert len(result) == 3
assert result[0] == "Text file content"
assert result[1] == b"Binary content"
assert isinstance(result[2], Image.Image)
@pytest.mark.parametrize(
"chunk_length, expected",
(
(
100000,
[
[
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
]
],
),
(
10,
[
["Lorem ipsum dolor sit amet, consectetur adipiscing elit."],
["Sed do eiusmod tempor incididunt ut labore"],
["et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud"],
[
"et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation "
"ullamco laboris nisi ut"
],
[
"et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation "
"ullamco laboris nisi ut aliquip ex ea commodo consequat."
],
[
"et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation "
"ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure "
"dolor in reprehenderit in"
],
[
"ip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in "
"voluptate velit esse cillum dolore eu"
],
[
"ip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in "
"voluptate velit esse cillum dolore eu fugiat nulla pariatur."
],
[
"ip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in "
"voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint "
"occaecat cupidatat non"
],
[
"dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non "
"proident, sunt in culpa qui officia"
],
[
"dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non "
"proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
],
[
"dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non "
"proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
],
],
),
(
20,
[
[
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod "
"tempor incididunt ut labore et dolore magna aliqua."
],
[
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod "
"tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim "
"veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip"
],
[
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut "
"aliquip ex ea commodo consequat."
],
[
"Duis aute irure dolor in reprehenderit in voluptate velit esse cillum "
"dolore eu fugiat nulla pariatur."
],
[
"Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia "
"deserunt"
],
["mollit anim id est laborum."],
],
),
),
)
def test_source_item_chunk_contents_text(chunk_length, expected, default_chunk_size):
"""Test SourceItem._chunk_contents for text content"""
source = SourceItem(
sha256=b"test123",
content="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
modality="text",
)
default_chunk_size(chunk_length)
assert source._chunk_contents() == expected
def test_source_item_chunk_contents_image(tmp_path):
"""Test SourceItem._chunk_contents for image content"""
image_file = tmp_path / "test.png"
img = Image.new("RGB", (10, 10), color="red")
img.save(image_file)
source = SourceItem(
sha256=b"test123",
filename=str(image_file),
modality="image",
mime_type="image/png",
)
result = source._chunk_contents()
assert len(result) == 1
assert len(result[0]) == 1
assert isinstance(result[0][0], Image.Image)
def test_source_item_chunk_contents_mixed(tmp_path):
"""Test SourceItem._chunk_contents for image content"""
image_file = tmp_path / "test.png"
img = Image.new("RGB", (10, 10), color="red")
img.save(image_file)
source = SourceItem(
sha256=b"test123",
content="Bla bla",
filename=str(image_file),
modality="image",
mime_type="image/png",
)
result = source._chunk_contents()
assert len(result) == 2
assert result[0][0] == "Bla bla"
assert isinstance(result[1][0], Image.Image)
@pytest.mark.parametrize(
"texts, expected_content",
(
([], None),
(["", " \n ", " "], None),
(["Hello"], "Hello"),
(["Hello", "World"], "Hello\n\nWorld"),
(["Hello", "World", ""], "Hello\n\nWorld"),
(["Hello", "World", "", ""], "Hello\n\nWorld"),
(["Hello", "World", "", "", ""], "Hello\n\nWorld"),
(["Hello", "World", "", "", "", ""], "Hello\n\nWorld"),
(["Hello", "World", "", "", "", "", "bla"], "Hello\n\nWorld\n\nbla"),
),
)
def test_source_item_make_chunk(tmp_path, texts, expected_content):
"""Test SourceItem._make_chunk method"""
source = SourceItem(
sha256=b"test123", content="test", modality="text", tags=["tag1"]
)
# Create actual image
image_file = tmp_path / "test.png"
img = Image.new("RGB", (1, 1), color="red")
img.save(image_file)
# Use object.__setattr__ to set filename
object.__setattr__(img, "filename", str(image_file))
data = [*texts, img]
metadata = {"extra": "data"}
chunk = source._make_chunk(data, metadata)
assert chunk.id is not None
assert chunk.source == source
assert cast(str, chunk.content) == expected_content
assert cast(list[str], chunk.file_paths) == [str(image_file)]
assert chunk.embedding_model is not None
# Check that metadata is merged correctly
expected_payload = {"source_id": source.id, "tags": ["tag1"], "extra": "data"}
assert chunk.item_metadata == expected_payload
def test_source_item_as_payload():
source = SourceItem(
id=123,
sha256=b"test123",
content="test",
modality="text",
tags=["tag1", "tag2"],
)
payload = source.as_payload()
assert payload == {"source_id": 123, "tags": ["tag1", "tag2"]}
@pytest.mark.parametrize(
"content,filename,expected",
[
("Test content", None, "Test content"),
(None, "test.txt", "test.txt"),
("Test content", "test.txt", "Test content"), # content takes precedence
(None, None, None),
],
)
def test_source_item_display_contents(content, filename, expected):
"""Test SourceItem.display_contents property"""
source = SourceItem(
sha256=b"test123", content=content, filename=filename, modality="text"
)
assert source.display_contents == expected
def test_unique_source_items_same_commit(db_session: Session):
@ -43,3 +418,484 @@ def test_unique_source_items_previous_commit(db_session: Session):
(b"1234567893", "test5"),
(b"1234567894", "test6"),
]
def test_source_item_chunk_contents_empty_content():
"""Test SourceItem._chunk_contents with empty content"""
source = SourceItem(sha256=b"test123", content=None, modality="text")
assert source._chunk_contents() == []
def test_source_item_chunk_contents_no_mime_type(tmp_path):
"""Test SourceItem._chunk_contents with filename but no mime_type"""
image_file = tmp_path / "test.png"
img = Image.new("RGB", (10, 10), color="red")
img.save(image_file)
source = SourceItem(
sha256=b"test123", filename=str(image_file), modality="image", mime_type=None
)
assert source._chunk_contents() == []
@pytest.mark.parametrize(
"content,file_paths,description",
[
("Test content", None, "content is set"),
(None, ["test.txt"], "file_paths is set"),
],
)
def test_chunk_constraint_validation(
db_session: Session, content, file_paths, description
):
"""Test that Chunk enforces the constraint that either file_paths or content must be set"""
source = SourceItem(sha256=b"test123", content="test", modality="text")
db_session.add(source)
db_session.commit()
chunk = Chunk(
source=source,
content=content,
file_paths=file_paths,
embedding_model="test-model",
)
db_session.add(chunk)
db_session.commit()
assert chunk.id is not None
@pytest.mark.parametrize(
"modality,expected_modality",
[
(None, "email"), # Default case
("custom", "custom"), # Override case
],
)
def test_mail_message_modality(modality, expected_modality):
"""Test MailMessage modality setting"""
kwargs = {"sha256": b"test", "content": "test"}
if modality is not None:
kwargs["modality"] = modality
mail_message = MailMessage(**kwargs)
# The __init__ method should set the correct modality
assert hasattr(mail_message, "modality")
@pytest.mark.parametrize(
"sender,folder,expected_path",
[
("user@example.com", "INBOX", "user_example_com/INBOX"),
("user+tag@example.com", "Sent Items", "user_tag_example_com/Sent_Items"),
("user@domain.co.uk", None, "user_domain_co_uk/INBOX"),
("user@domain.co.uk", "", "user_domain_co_uk/INBOX"),
],
)
def test_mail_message_attachments_path(sender, folder, expected_path):
"""Test MailMessage.attachments_path property"""
mail_message = MailMessage(
sha256=b"test", content="test", sender=sender, folder=folder
)
with patch.object(settings, "FILE_STORAGE_DIR", "/tmp/storage"):
result = mail_message.attachments_path
assert str(result) == f"/tmp/storage/{expected_path}"
@pytest.mark.parametrize(
"filename,expected",
[
("document.pdf", "document.pdf"),
("file with spaces.txt", "file_with_spaces.txt"),
("file@#$%^&*().doc", "file.doc"),
("no-extension", "no_extension"),
("multiple.dots.in.name.txt", "multiple_dots_in_name.txt"),
],
)
def test_mail_message_safe_filename(tmp_path, filename, expected):
"""Test MailMessage.safe_filename method"""
mail_message = MailMessage(
sha256=b"test", content="test", sender="user@example.com", folder="INBOX"
)
with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
result = mail_message.safe_filename(filename)
# Check that the path is correct
expected_path = tmp_path / "user_example_com" / "INBOX" / expected
assert result == expected_path
# Check that the directory was created
assert result.parent.exists()
@pytest.mark.parametrize(
"sent_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_mail_message_as_payload(sent_at, expected_date):
"""Test MailMessage.as_payload method"""
mail_message = MailMessage(
sha256=b"test",
content="test",
message_id="<test@example.com>",
subject="Test Subject",
sender="sender@example.com",
recipients=["recipient1@example.com", "recipient2@example.com"],
folder="INBOX",
sent_at=sent_at,
tags=["tag1", "tag2"],
)
# Manually set id for testing
object.__setattr__(mail_message, "id", 123)
payload = mail_message.as_payload()
expected = {
"source_id": 123,
"message_id": "<test@example.com>",
"subject": "Test Subject",
"sender": "sender@example.com",
"recipients": ["recipient1@example.com", "recipient2@example.com"],
"folder": "INBOX",
"tags": [
"tag1",
"tag2",
"sender@example.com",
"recipient1@example.com",
"recipient2@example.com",
],
"date": expected_date,
}
assert payload == expected
def test_mail_message_parsed_content():
"""Test MailMessage.parsed_content property with actual email parsing"""
# Use a simple email format that the parser can handle
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
result = mail_message.parsed_content
# Just test that it returns a dict-like object
assert isinstance(result, dict)
assert "body" in result
def test_mail_message_body_property():
"""Test MailMessage.body property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
assert mail_message.body == "Test Body Content"
def test_mail_message_display_contents():
"""Test MailMessage.display_contents property with actual email parsing"""
email_content = """From: sender@example.com
To: recipient@example.com
Subject: Test Subject
Test Body Content"""
mail_message = MailMessage(
sha256=b"test", content=email_content, message_id="<test@example.com>"
)
expected = (
"\nSubject: Test Subject\nFrom: \nTo: \nDate: \nBody: \nTest Body Content\n"
)
assert mail_message.display_contents == expected
@pytest.mark.parametrize(
"created_at,expected_date",
[
(datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
(None, None),
],
)
def test_email_attachment_as_payload(created_at, expected_date):
"""Test EmailAttachment.as_payload method"""
attachment = EmailAttachment(
sha256=b"test",
filename="document.pdf",
mime_type="application/pdf",
size=1024,
mail_message_id=123,
created_at=created_at,
tags=["pdf", "document"],
)
# Manually set id for testing
object.__setattr__(attachment, "id", 456)
payload = attachment.as_payload()
expected = {
"filename": "document.pdf",
"content_type": "application/pdf",
"size": 1024,
"created_at": expected_date,
"mail_message_id": 123,
"source_id": 456,
"tags": ["pdf", "document"],
}
assert payload == expected
@pytest.mark.parametrize(
"has_filename,content_source,expected_content",
[
(True, "file", b"test file content"),
(False, "content", "attachment content"),
],
)
@patch("memory.common.extract.extract_data_chunks")
def test_email_attachment_data_chunks(
mock_extract, has_filename, content_source, expected_content, tmp_path
):
"""Test EmailAttachment.data_chunks method"""
from memory.common.extract import DataChunk
mock_extract.return_value = [
DataChunk(data=["extracted text"], metadata={"source": content_source})
]
if has_filename:
# Create a test file
test_file = tmp_path / "test.txt"
test_file.write_bytes(b"test file content")
attachment = EmailAttachment(
sha256=b"test",
filename=str(test_file),
mime_type="text/plain",
mail_message_id=123,
)
else:
attachment = EmailAttachment(
sha256=b"test",
content="attachment content",
filename=None,
mime_type="text/plain",
mail_message_id=123,
)
# Mock _make_chunk to return a simple chunk
mock_chunk = Mock()
with patch.object(attachment, "_make_chunk", return_value=mock_chunk) as mock_make:
result = attachment.data_chunks({"extra": "metadata"})
# Verify the method calls
mock_extract.assert_called_once_with("text/plain", expected_content)
mock_make.assert_called_once_with(
["extracted text"], {"extra": "metadata", "source": content_source}
)
assert result == [mock_chunk]
def test_email_attachment_cascade_delete(db_session: Session):
"""Test that EmailAttachment is deleted when MailMessage is deleted"""
mail_message = MailMessage(
sha256=b"test_email",
content="test email",
message_id="<test@example.com>",
subject="Test",
sender="sender@example.com",
recipients=["recipient@example.com"],
folder="INBOX",
)
db_session.add(mail_message)
db_session.commit()
attachment = EmailAttachment(
sha256=b"test_attachment",
content="attachment content",
mail_message=mail_message,
filename="test.txt",
mime_type="text/plain",
size=100,
modality="attachment", # Set modality explicitly
)
db_session.add(attachment)
db_session.commit()
attachment_id = attachment.id
# Delete the mail message
db_session.delete(mail_message)
db_session.commit()
# Verify the attachment was also deleted
deleted_attachment = (
db_session.query(EmailAttachment).filter_by(id=attachment_id).first()
)
assert deleted_attachment is None
# BookSection tests
@pytest.mark.parametrize(
"pages,expected_chunks",
[
# No pages
([], []),
# Single page
(["Page 1 content"], [("Page 1 content", 10)]),
# Multiple pages
(
["Page 1", "Page 2", "Page 3"],
[
("Page 1", 10),
("Page 2", 11),
("Page 3", 12),
],
),
# Empty/whitespace pages filtered out
(["", " ", "Page 3"], [("Page 3", 12)]),
# All empty - no chunks created
(["", " ", " "], []),
],
)
def test_book_section_data_chunks(pages, expected_chunks):
"""Test BookSection.data_chunks with various page combinations"""
content = "\n\n".join(pages).strip()
book_section = BookSection(
sha256=b"test_section",
content=content,
modality="book",
book_id=1,
start_page=10,
end_page=10 + len(pages),
pages=pages,
)
chunks = book_section.data_chunks()
expected = [
(p, book_section.as_payload() | {"page": i}) for p, i in expected_chunks
]
if content:
expected.append((content, book_section.as_payload()))
assert [(c.content, c.item_metadata) for c in chunks] == expected
for c in chunks:
assert cast(list, c.file_paths) == []
@pytest.mark.parametrize(
"content,expected",
[
("", []),
("Short content", [["Short content"]]),
(
"This is a very long piece of content that should be chunked into multiple pieces when processed.",
[
[
"This is a very long piece of content that should be chunked into multiple pieces when processed."
],
["This is a very long piece of content that"],
["should be chunked into multiple pieces when"],
["processed."],
],
),
],
)
def test_blog_post_chunk_contents(content, expected, default_chunk_size):
default_chunk_size(10)
blog_post = BlogPost(
sha256=b"test_blog",
content=content,
modality="blog",
url="https://example.com/post",
images=[],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
assert blog_post._chunk_contents() == expected
def test_blog_post_chunk_contents_with_images(tmp_path):
"""Test BlogPost._chunk_contents with images"""
# Create test image files
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content="Content with images",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
]
assert result == [
["Content with images", img1_path.as_posix(), img2_path.as_posix()]
]
def test_blog_post_chunk_contents_with_image_long_content(tmp_path, default_chunk_size):
default_chunk_size(10)
img1_path = tmp_path / "img1.jpg"
img2_path = tmp_path / "img2.jpg"
for img_path in [img1_path, img2_path]:
img = Image.new("RGB", (10, 10), color="red")
img.save(img_path)
blog_post = BlogPost(
sha256=b"test_blog",
content=f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
modality="blog",
url="https://example.com/post",
images=[str(img1_path), str(img2_path)],
)
with patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10):
result = blog_post._chunk_contents()
result = [
[i if isinstance(i, str) else getattr(i, "filename") for i in c] for c in result
]
print(result)
assert result == [
[
f"First picture is here: {img1_path.as_posix()}\nSecond picture is here: {img2_path.as_posix()}",
img1_path.as_posix(),
img2_path.as_posix(),
],
[
f"First picture is here: {img1_path.as_posix()}",
img1_path.as_posix(),
],
[
f"Second picture is here: {img2_path.as_posix()}",
img2_path.as_posix(),
],
]

View File

@ -100,7 +100,7 @@ def test_find_new_urls_parse_error(mock_parse):
@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_empty_feed(mock_parse):
def test_find_new_urls_empty_feed(mock_parse, db_session):
"""Test handling of empty RSS feed."""
mock_parse.return_value = Mock(entries=[])
@ -110,7 +110,7 @@ def test_find_new_urls_empty_feed(mock_parse):
@patch("memory.workers.tasks.comic.feedparser.parse")
def test_find_new_urls_malformed_entries(mock_parse):
def test_find_new_urls_malformed_entries(mock_parse, db_session):
"""Test handling of malformed RSS entries."""
mock_parse.return_value = Mock(
entries=[

View File

@ -51,16 +51,16 @@ def chunk(request, test_image, db_session):
collection = request.param
if collection == "photo":
content = None
file_path = str(test_image)
file_paths = [str(test_image)]
else:
content = "Test content for reingestion"
file_path = None
file_paths = None
chunk = Chunk(
id=str(uuid.uuid4()),
source=SourceItem(id=1, modality=collection, sha256=b"123"),
content=content,
file_path=file_path,
file_paths=file_paths,
embedding_model="test-model",
checked_at=datetime(2025, 1, 1),
)
@ -146,9 +146,7 @@ def test_reingest_chunk(db_session, qdrant, chunk):
start = datetime.now()
test_vector = [0.1] * 1024
with patch.object(embedding, "embed_chunks", return_value=[test_vector]):
reingest_chunk(str(chunk.id), collection)
reingest_chunk(str(chunk.id), collection)
vectors = qd.search_vectors(qdrant, collection, test_vector, limit=1)
assert len(vectors) == 1
@ -198,7 +196,7 @@ def test_check_batch(db_session, qdrant):
id=f"00000000-0000-0000-0000-0000000000{i:02d}",
source=SourceItem(modality=modality, sha256=f"123{i}".encode()),
content="Test content",
file_path=None,
file_paths=None,
embedding_model="test-model",
checked_at=datetime(2025, 1, 1),
)
@ -252,7 +250,7 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size):
id=next(ids_generator),
source=SourceItem(modality=modality, sha256=f"{modality}-{i}".encode()),
content="Old content",
file_path=None,
file_paths=None,
embedding_model="test-model",
checked_at=old_time,
)
@ -267,7 +265,7 @@ def test_reingest_missing_chunks(db_session, qdrant, batch_size):
modality=modality, sha256=f"recent-{modality}-{i}".encode()
),
content="Recent content",
file_path=None,
file_paths=None,
embedding_model="test-model",
checked_at=now,
)