mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
add missing tests
This commit is contained in:
parent
50d0eb97db
commit
01ccea2733
@ -347,7 +347,7 @@ class SourceItem(Base):
|
||||
collection_name=modality,
|
||||
embedding_model=collections.collection_model(modality, text, images),
|
||||
item_metadata=extract.merge_metadata(
|
||||
self.as_payload(), data.metadata, metadata
|
||||
cast(dict[str, Any], self.as_payload()), data.metadata, metadata
|
||||
),
|
||||
)
|
||||
return chunk
|
||||
|
@ -15,7 +15,6 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
|
@ -547,3 +547,141 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
|
||||
# Verify both the MailMessage and SourceItem records are deleted
|
||||
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
|
||||
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "content,image_paths,expected_chunks",
    [
        ("", [], 0),  # no content -> no chunks
        (" \n ", [], 0),  # whitespace only -> no chunks
        ("Short content", [], 1),  # short text -> a single chunk
        ("A" * 10, [], 1),  # very short text -> a single chunk
    ],
)
def test_chunk_mixed_basic_cases(tmp_path, content, image_paths, expected_chunks):
    """Check the number of chunks chunk_mixed returns for empty and short text."""
    from memory.common.db.models.source_item import chunk_mixed

    # Write one 1x1 placeholder PNG per requested image and remember its file name
    stored_names = []
    for index, _unused in enumerate(image_paths):
        target = tmp_path / f"test{index}.png"
        Image.new("RGB", (1, 1), color="red").save(target)
        stored_names.append(target.name)

    # FILE_STORAGE_DIR is redirected to tmp_path for the duration of the call
    with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
        chunks = chunk_mixed(content, stored_names)

    assert len(chunks) == expected_chunks
|
||||
|
||||
|
||||
def test_chunk_mixed_with_images(tmp_path):
    """Check that chunk_mixed puts the text and both images in its first chunk."""
    from memory.common.db.models.source_item import chunk_mixed

    # Two 1x1 images written into the temporary directory
    red_png = tmp_path / "image1.png"
    blue_jpg = tmp_path / "image2.jpg"
    for path, colour in ((red_png, "red"), (blue_jpg, "blue")):
        Image.new("RGB", (1, 1), color=colour).save(path)

    content = "This content mentions image1.png and image2.jpg"

    with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
        result = chunk_mixed(content, [red_png.name, blue_jpg.name])

    assert len(result) >= 1

    # The leading chunk must hold the stripped text plus exactly two PIL images
    assert content.strip() in result[0].data
    image_count = sum(1 for entry in result[0].data if isinstance(entry, Image.Image))
    assert image_count == 2
|
||||
|
||||
|
||||
def test_chunk_mixed_long_content(tmp_path):
    """Check that chunk_mixed splits long text into more than one chunk.

    The chunker is patched so that the text is always treated as long:
    DEFAULT_CHUNK_TOKENS is lowered to 10 and approx_token_count is forced
    to report 100 tokens.
    """
    from memory.common.db.models.source_item import chunk_mixed

    # 5 words repeated 50 times: about 250 words
    long_content = "Lorem ipsum dolor sit amet, " * 50

    with (
        patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
        patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10),
        # 100 reported tokens is > 2 * 10, which forces the chunking path
        patch.object(chunker, "approx_token_count", return_value=100),
    ):
        result = chunk_mixed(long_content, [])

    # Expect several chunks: the full text, the chunked pieces and a summary
    assert len(result) > 1

    # The first chunk carries the full (stripped) text
    assert long_content.strip() in result[0].data

    # The last chunk is expected to be the summary; its exact content cannot
    # be checked without mocking the summarizer, so only require some data
    assert result[-1].data
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sha256_values,expected_committed",
    [
        ([b"unique1", b"unique2", b"unique3"], 3),  # All unique
        ([b"duplicate", b"duplicate", b"unique"], 2),  # One duplicate pair
        ([b"same", b"same", b"same"], 1),  # All duplicates
        ([b"dup1", b"dup1", b"dup2", b"dup2"], 2),  # Two duplicate pairs
    ],
)
def test_handle_duplicate_sha256_behavior(
    db_session: Session, sha256_values, expected_committed
):
    """Test that committing SourceItems with repeated sha256 values keeps one row per value.

    Args:
        db_session: database session fixture.
        sha256_values: sha256 value for each SourceItem added to the session.
        expected_committed: number of rows expected after the commit.
    """
    # Add one SourceItem per sha256 value; the session keeps track of them,
    # so no separate list of the created objects is needed
    for i, sha256 in enumerate(sha256_values):
        db_session.add(
            SourceItem(sha256=sha256, content=f"test content {i}", modality="text")
        )

    # The commit is what should trigger the duplicate handling
    db_session.commit()

    # Only the de-duplicated rows should have been stored
    committed_count = db_session.query(SourceItem).count()
    assert committed_count == expected_committed

    # Every sha256 stored in the database must be distinct
    sha256_in_db = [row[0] for row in db_session.query(SourceItem.sha256).all()]
    assert len(sha256_in_db) == len(set(sha256_in_db))
|
||||
|
||||
|
||||
def test_handle_duplicate_sha256_with_existing_data(db_session: Session):
    """Check duplicate handling against a row that is already committed."""
    # Seed the database with one committed item
    db_session.add(SourceItem(sha256=b"existing", content="original", modality="text"))
    db_session.commit()

    # One item reusing the stored sha256 (expected to be rejected) and one with
    # a fresh sha256 (expected to be kept)
    clashing = SourceItem(sha256=b"existing", content="duplicate", modality="text")
    fresh = SourceItem(sha256=b"new_unique", content="new content", modality="text")
    for candidate in (clashing, fresh):
        db_session.add(candidate)

    db_session.commit()

    # Two rows in total: the seeded one and the fresh one
    assert db_session.query(SourceItem).count() == 2

    # The seeded row must still carry its original content
    stored = db_session.query(SourceItem).filter_by(sha256=b"existing").first()
    assert stored is not None
    assert str(stored.content) == "original"
|
||||
|
104
tests/memory/common/db/models/test_users.py
Normal file
104
tests/memory/common/db/models/test_users.py
Normal file
@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
from memory.common.db.models.users import hash_password, verify_password
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "password",
    [
        "simple_password",
        "complex_P@ssw0rd!",
        "very_long_password_with_many_characters_1234567890",
        "",
        "unicode_password_тест_😀",
        "password with spaces",
    ],
)
def test_hash_password_format(password):
    """Check that hash_password returns "salt:hash" built from lowercase hex digits."""
    hex_digits = set("0123456789abcdef")

    salt, separator, hash_value = hash_password(password).partition(":")

    # partition() leaves the separator empty when the result contains no ":"
    assert separator == ":"

    # Salt: 32 hex characters (16 bytes, two characters each)
    assert len(salt) == 32
    assert set(salt) <= hex_digits

    # Hash: 64 hex characters (32 bytes, two characters each)
    assert len(hash_value) == 64
    assert set(hash_value) <= hex_digits
|
||||
|
||||
|
||||
def test_hash_password_uniqueness():
    """Check that hashing one password twice gives two different, valid hashes."""
    password = "test_password"
    first_hash = hash_password(password)
    second_hash = hash_password(password)

    # The two results are expected to differ (random salt)
    assert first_hash != second_hash

    # Each result must still verify against the password
    for candidate in (first_hash, second_hash):
        assert verify_password(password, candidate)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "password,expected",
    [
        ("correct_password", True),
        ("wrong_password", False),
        ("", False),
        ("CORRECT_PASSWORD", False),  # differs only in case
    ],
)
def test_verify_password_correctness(password, expected):
    """Check verify_password against a hash of "correct_password"."""
    stored_hash = hash_password("correct_password")

    assert verify_password(password, stored_hash) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "malformed_hash",
    [
        "invalid_format",
        "no_colon_here",
        ":empty_salt",
        "salt:",  # nothing after the colon
        "",
        "too:many:colons:here",
        "salt:invalid_hex_zzz",
        "salt:too_short_hash",
    ],
)
def test_verify_password_malformed_hash(malformed_hash):
    """Check that verify_password returns False for a malformed stored hash."""
    outcome = verify_password("any_password", malformed_hash)

    # Must be the literal False, not merely a falsy value
    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_password",
    [
        "simple",
        "complex_P@ssw0rd!123",
        "",
        "unicode_тест_😀",
        "password with spaces and symbols !@#$%^&*()",
    ],
)
def test_hash_verify_roundtrip(test_password):
    """Check that a hash verifies its own password and rejects an altered one."""
    stored_hash = hash_password(test_password)
    altered_password = test_password + "_wrong"

    # The password that produced the hash is accepted
    assert verify_password(test_password, stored_hash)

    # The same password with a suffix appended is rejected
    assert not verify_password(altered_password, stored_hash)
|
@ -1,23 +1,30 @@
|
||||
from unittest.mock import Mock
|
||||
import pytest
|
||||
from typing import cast
|
||||
from PIL import Image
|
||||
|
||||
from memory.common import collections
|
||||
from memory.common import collections, settings
|
||||
from memory.common.embedding import (
|
||||
as_string,
|
||||
embed_chunks,
|
||||
embed_mixed,
|
||||
embed_text,
|
||||
break_chunk,
|
||||
embed_by_model,
|
||||
)
|
||||
from memory.common.extract import DataChunk
|
||||
from memory.common.extract import DataChunk, MulitmodalChunk
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embed(mock_voyage_client):
    """Replace the voyage client's embedding methods with counting fakes.

    Both ``embed`` and ``multimodal_embed`` become ``Mock`` objects, so tests
    can assert on how they were called. Each input item receives the next
    vector from a shared sequence ``[0], [1], [2], ...``.

    Returns:
        The patched ``mock_voyage_client``.
    """
    # One generator shared by both methods, so vectors are never repeated
    vectors = ([i] for i in range(1000))

    def embed_func(texts, model, input_type):
        return Mock(embeddings=[next(vectors) for _ in texts])

    mock_voyage_client.embed = Mock(side_effect=embed_func)
    mock_voyage_client.multimodal_embed = Mock(side_effect=embed_func)

    return mock_voyage_client
|
||||
|
||||
@ -52,3 +59,182 @@ def test_embed_text(mock_embed):
|
||||
def test_embed_mixed(mock_embed):
    """Check that embed_mixed returns one vector for a single text chunk."""
    only_chunk = DataChunk(data=["text"])

    assert embed_mixed([only_chunk]) == [[0]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "input_data, expected_output",
    [
        # plain strings: surrounding whitespace is stripped
        ("hello world", "hello world"),
        (" hello world \n", "hello world"),
        # lists: the items are joined with newlines
        (
            cast(list[MulitmodalChunk], ["first chunk", "second chunk", "third chunk"]),
            "first chunk\nsecond chunk\nthird chunk",
        ),
        (cast(list[MulitmodalChunk], []), ""),
        (
            cast(list[MulitmodalChunk], ["", "valid text", " ", "another text"]),
            "valid text\n\nanother text",
        ),
    ],
)
def test_as_string_basic_cases(input_data, expected_output):
    """Check as_string on plain strings and flat lists of strings."""
    rendered = as_string(input_data)

    assert rendered == expected_output
|
||||
|
||||
|
||||
def test_as_string_with_nested_lists():
    """Check that as_string flattens a nested list into newline-joined text."""
    # Kept apart from the parametrized cases because the input type differs
    nested_input = [["nested", "items"], "single item"]

    assert as_string(nested_input) == "nested\nitems\nsingle item"
|
||||
|
||||
|
||||
def test_embed_chunks_with_text_model(mock_embed):
    """Check that the text model sends flattened strings to client.embed."""
    text_model = settings.TEXT_EMBEDDING_MODEL
    chunks = cast(list[list[MulitmodalChunk]], [["text1"], ["text2"]])

    vectors = embed_chunks(chunks, model=text_model)

    assert vectors == [[0], [1]]
    mock_embed.embed.assert_called_once_with(
        ["text1", "text2"],
        model=text_model,
        input_type="document",
    )
|
||||
|
||||
|
||||
def test_embed_chunks_with_mixed_model(mock_embed):
    """Check that the mixed model passes chunks unchanged to multimodal_embed."""
    mixed_model = settings.MIXED_EMBEDDING_MODEL
    chunks = cast(list[list[MulitmodalChunk]], [["text with image"], ["another chunk"]])

    vectors = embed_chunks(chunks, model=mixed_model)

    assert vectors == [[0], [1]]
    mock_embed.multimodal_embed.assert_called_once_with(
        chunks, model=mixed_model, input_type="document"
    )
|
||||
|
||||
|
||||
def test_embed_chunks_with_query_input_type(mock_embed):
    """Check that input_type="query" is forwarded to client.embed."""
    chunks = cast(list[list[MulitmodalChunk]], [["query text"]])

    vectors = embed_chunks(chunks, input_type="query")

    assert vectors == [[0]]
    # No model was passed, and the call is expected to use TEXT_EMBEDDING_MODEL
    mock_embed.embed.assert_called_once_with(
        ["query text"], model=settings.TEXT_EMBEDDING_MODEL, input_type="query"
    )
|
||||
|
||||
|
||||
def test_embed_chunks_empty_list(mock_embed):
    """Check that embed_chunks returns an empty list for an empty input."""
    assert embed_chunks([]) == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "data, chunk_size, expected_result",
    [
        # text well below the chunk size is expected back unchanged
        (["short text"], 100, ["short text"]),
        (["some text content"], 200, ["some text content"]),
        # no data gives no pieces
        ([], 100, []),
    ],
)
def test_break_chunk_simple_cases(data, chunk_size, expected_result):
    """Check break_chunk on inputs that need no splitting."""
    pieces = break_chunk(DataChunk(data=data), chunk_size=chunk_size)

    assert pieces == expected_result
|
||||
|
||||
|
||||
def test_break_chunk_with_long_text():
    """Check that break_chunk splits long text into several string pieces."""
    # 200 repetitions of "word ", intended to exceed the chunk_size of 50
    oversized = DataChunk(data=["word " * 200])

    pieces = break_chunk(oversized, chunk_size=50)

    # More than one piece, and every piece is a plain string
    assert len(pieces) > 1
    for piece in pieces:
        assert isinstance(piece, str)
|
||||
|
||||
|
||||
def test_break_chunk_with_mixed_data_types():
    """Check break_chunk on a chunk that holds both text and an image."""
    # A Mock with the PIL image spec stands in for a real image
    fake_image = Mock(spec=Image.Image)
    chunk = DataChunk(data=["text content", fake_image])

    pieces = break_chunk(chunk, chunk_size=100)

    # Expect at least one text piece plus the original chunk itself
    assert len(pieces) >= 2
    text_pieces = [piece for piece in pieces if isinstance(piece, str)]
    assert text_pieces
    # The original chunk is kept in the output because it contains non-string data
    assert chunk in pieces
|
||||
|
||||
|
||||
def test_embed_by_model_with_matching_chunks(mock_embed):
    """Check that embed_by_model assigns vectors, in order, to matching chunks."""
    # Two mock chunks that both use the requested embedding model
    chunk1 = Mock(spec=Chunk)
    chunk2 = Mock(spec=Chunk)
    for mock_chunk, content in ((chunk1, "chunk1 content"), (chunk2, "chunk2 content")):
        mock_chunk.embedding_model = "test-model"
        mock_chunk.chunks = [content]

    result = embed_by_model(cast(list[Chunk], [chunk1, chunk2]), "test-model")

    assert len(result) == 2
    assert chunk1.vector == [0]
    assert chunk2.vector == [1]
    assert result == [chunk1, chunk2]
|
||||
|
||||
|
||||
def test_embed_by_model_with_no_matching_chunks(mock_embed):
    """Check that chunks using another model are not returned or embedded."""
    other_chunk = Mock(spec=Chunk)
    other_chunk.embedding_model = "different-model"
    # Delete the auto-created attribute so hasattr() can show that no vector
    # was assigned
    del other_chunk.vector

    result = embed_by_model(cast(list[Chunk], [other_chunk]), "test-model")

    assert result == []
    assert not hasattr(other_chunk, "vector")
|
||||
|
||||
|
||||
def test_embed_by_model_with_mixed_models(mock_embed):
    """Check that only chunks whose model matches are embedded and returned."""

    def make_chunk(model_name, content):
        # Mock chunk carrying just the attributes this test sets up
        fake = Mock(spec=Chunk)
        fake.embedding_model = model_name
        fake.chunks = [content]
        return fake

    chunk1 = make_chunk("test-model", "chunk1 content")
    chunk2 = make_chunk("other-model", "chunk2 content")
    chunk3 = make_chunk("test-model", "chunk3 content")

    result = embed_by_model(cast(list[Chunk], [chunk1, chunk2, chunk3]), "test-model")

    # chunk2 uses a different model, so it must be absent from the result
    assert len(result) == 2
    assert chunk1 in result
    assert chunk3 in result
    assert chunk2 not in result
    # The two matching chunks receive consecutive vectors
    assert chunk1.vector == [0]
    assert chunk3.vector == [1]
|
||||
|
||||
|
||||
def test_embed_by_model_with_empty_chunks(mock_embed):
    """Check that embed_by_model returns an empty list for no chunks."""
    assert embed_by_model([], "test-model") == []
|
||||
|
||||
|
||||
def test_embed_by_model_calls_embed_chunks_correctly(mock_embed):
    """Check that embed_by_model makes a single embed call with the right arguments.

    Two mock chunks use the requested model. The client's ``embed`` must be
    called exactly once, with the flattened contents, that model name and the
    "document" input type.
    """
    chunk1 = Mock(spec=Chunk)
    chunk1.embedding_model = "test-model"
    chunk1.chunks = ["content1"]

    chunk2 = Mock(spec=Chunk)
    chunk2.embedding_model = "test-model"
    chunk2.chunks = ["content2"]

    chunks = cast(list[Chunk], [chunk1, chunk2])
    embed_by_model(chunks, "test-model")

    # The per-chunk lists ["content1"] and ["content2"] reach the client as
    # one flat list of strings
    mock_embed.embed.assert_called_once_with(
        ["content1", "content2"], model="test-model", input_type="document"
    )
|
||||
|
Loading…
x
Reference in New Issue
Block a user