diff --git a/src/memory/common/embedding.py b/src/memory/common/embedding.py
index 9a4e752..3fc1f20 100644
--- a/src/memory/common/embedding.py
+++ b/src/memory/common/embedding.py
@@ -49,12 +49,13 @@ def embed_chunks(
 def break_chunk(
     chunk: extract.DataChunk, chunk_size: int = DEFAULT_CHUNK_TOKENS
 ) -> list[extract.MulitmodalChunk]:
-    result = []
+    result: list[extract.MulitmodalChunk] = []
     for c in chunk.data:
         if isinstance(c, str):
             result += chunk_text(c, chunk_size, OVERLAP_TOKENS)
         else:
-            result.append(chunk)
+            # Non-string items (e.g., images) are passed through directly
+            result.append(c)
     return result
 
 
diff --git a/tests/memory/common/test_embedding.py b/tests/memory/common/test_embedding.py
index 7e45a59..3407dce 100644
--- a/tests/memory/common/test_embedding.py
+++ b/tests/memory/common/test_embedding.py
@@ -153,11 +153,29 @@ def test_break_chunk_with_mixed_data_types():
     chunk = DataChunk(data=["text content", mock_image])
     result = break_chunk(chunk, chunk_size=100)
 
-    # Should have text chunks plus the original chunk (since it's not a string)
+    # Should have text chunks plus the image (non-string items are passed through)
     assert len(result) >= 2
     assert any(isinstance(item, str) for item in result)
-    # The original chunk should be preserved when it contains mixed data
-    assert chunk in result
+    # The individual non-string item (image) should be in result, not the DataChunk
+    assert mock_image in result
+    # The DataChunk itself should NOT be in the result
+    assert chunk not in result
+
+
+def test_break_chunk_preserves_non_string_items():
+    """Non-string items (like images) should be preserved individually."""
+    mock_image1 = Mock(spec=Image.Image)
+    mock_image2 = Mock(spec=Image.Image)
+    chunk = DataChunk(data=[mock_image1, "some text", mock_image2])
+    result = break_chunk(chunk, chunk_size=100)
+
+    # Both images should be in result
+    assert mock_image1 in result
+    assert mock_image2 in result
+    # Text should be chunked
+    assert "some text" in result
+    # Total should be 3 items (2 images + 1 short text)
+    assert len(result) == 3
 
 
 def test_embed_by_model_with_matching_chunks(mock_embed):