Compare commits

...

2 Commits

Author            SHA1        Message                Date
Daniel O'Connell  06eec621c1  shuffle around search  2025-06-28 04:33:27 +02:00
Daniel O'Connell  01ccea2733  add missing tests      2025-06-28 02:30:54 +02:00
14 changed files with 631 additions and 307 deletions

View File

@@ -90,6 +90,7 @@ export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
<div className="search-result-card">
<h4>{title}</h4>
<Tag tags={tags} />
<Metadata metadata={metadata} />
<div className="image-container">
{mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
</div>
@@ -115,7 +116,7 @@ export const Metadata = ({ metadata }: { metadata: any }) => {
return (
<div className="metadata">
<ul>
{Object.entries(metadata).map(([key, value]) => (
{Object.entries(metadata).filter(([key, value]) => ![null, undefined].includes(value as any)).map(([key, value]) => (
<MetadataItem key={key} item={key} value={value} />
))}
</ul>
@@ -154,19 +155,19 @@ export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
}
export const SearchResult = ({ result }: { result: SearchItem }) => {
if (result.mime_type.startsWith('image/')) {
if (result.mime_type?.startsWith('image/')) {
return <ImageResult {...result} />
}
if (result.mime_type.startsWith('text/markdown')) {
if (result.mime_type?.startsWith('text/markdown')) {
return <MarkdownResult {...result} />
}
if (result.mime_type.startsWith('text/')) {
if (result.mime_type?.startsWith('text/')) {
return <TextResult {...result} />
}
if (result.mime_type.startsWith('application/pdf')) {
if (result.mime_type?.startsWith('application/pdf')) {
return <PDFResult {...result} />
}
if (result.mime_type.startsWith('message/rfc822')) {
if (result.mime_type?.startsWith('message/rfc822')) {
return <EmailResult {...result} />
}
console.log(result)

View File

@@ -109,14 +109,12 @@ async def search_knowledge_base(
search_filters = SearchFilters(**filters)
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
upload_data = extract.extract_text(query)
upload_data = extract.extract_text(query, skip_summary=True)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.4,
min_multimodal_score=0.25,
filters=search_filters,
)

View File

@@ -1,4 +1,4 @@
from .search import search
from .utils import SearchResult, SearchFilters
from .types import SearchResult, SearchFilters
__all__ = ["search", "SearchResult", "SearchFilters"]

View File

@@ -2,13 +2,15 @@
Search endpoints for the knowledge base API.
"""
import asyncio
from hashlib import sha256
import logging
import bm25s
import Stemmer
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
from memory.api.search.types import SearchFilters
from memory.common import extract
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, ConfidenceScore
@@ -20,7 +22,7 @@ async def search_bm25(
modalities: set[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
) -> list[str]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities),
@@ -65,21 +67,18 @@ async def search_bm25(
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
return list(item_scores.keys())
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source_data = SourceData.from_chunk(chunk)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=chunk.source.as_payload(),
preview=None,
search_method="bm25",
)
results.append((source_data, annotated))
return results
async def search_bm25_chunks(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = SearchFilters(),
timeout: int = 2,
) -> list[str]:
query = " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)])
return await asyncio.wait_for(
search_bm25(query, modalities, limit, filters),
timeout,
)
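
search_bm25 now returns bare chunk IDs, and search_bm25_chunks is the async entry point that flattens the query text and bounds the lookup. A hedged usage sketch (the "text" modality name is only illustrative):

    import asyncio

    from memory.api.search.bm25 import search_bm25_chunks
    from memory.common import extract

    async def demo() -> None:
        chunks = [extract.DataChunk(data=["bm25 query terms"])]
        # Times out after 2s by default; returns chunk IDs, not scored tuples.
        ids = await search_bm25_chunks(chunks, modalities={"text"}, limit=5)
        print(ids)  # hydration into Chunk rows happens later, in search_chunks()

    asyncio.run(demo())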

View File

@@ -1,65 +1,20 @@
import base64
import io
import logging
import asyncio
from typing import Any, Callable, Optional, cast
from typing import Any, Callable, cast
import qdrant_client
from PIL import Image
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, extract, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
from memory.common import embedding, extract, qdrant
from memory.common.collections import (
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.api.search.types import SearchFilters
logger = logging.getLogger(__name__)
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if (
not previews
and isinstance(item, str)
and len(item) > settings.MAX_NON_PREVIEW_LENGTH
):
return item[: settings.MAX_NON_PREVIEW_LENGTH] + "..."
elif isinstance(item, str):
if len(item) > settings.MAX_PREVIEW_LENGTH:
return None
return item
if isinstance(item, Image.Image):
buffer = io.BytesIO()
format = item.format or "PNG"
item.save(buffer, format=format)
mime_type = f"image/{format.lower()}"
return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
elif isinstance(item, bytes):
return base64.b64encode(item).decode("utf-8")
else:
raise ValueError(f"Unsupported item type: {type(item)}")
metadata = search_result.payload or {}
metadata = {
k: v
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
# Prefetch all needed source data while in session
return SourceData.from_chunk(chunk), AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
preview=serialize_item(chunk.data[0]) if chunk.data else None,
search_method="embeddings",
)
async def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
@@ -178,15 +133,14 @@ def merge_filters(
return filters
async def search_embeddings(
async def search_chunks(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = {},
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
) -> list[str]:
"""
Search across knowledge base using text query and optional files.
@@ -218,9 +172,38 @@ async def search_embeddings(
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
return list(found_chunks.keys())
async def search_chunks_embeddings(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = SearchFilters(),
timeout: int = 2,
) -> list[str]:
all_ids = await asyncio.gather(
asyncio.wait_for(
search_chunks(
data,
modalities & TEXT_COLLECTIONS,
limit,
0.4,
filters,
False,
),
timeout,
),
asyncio.wait_for(
search_chunks(
data,
modalities & MULTIMODAL_COLLECTIONS,
limit,
0.25,
filters,
True,
),
timeout,
),
)
return list({id for ids in all_ids for id in ids})
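
The text and multimodal searches now run concurrently with their score floors (0.4 and 0.25) baked in, and the resulting IDs are unioned. A minimal self-contained sketch of that gather/wait_for/union shape, with fake_search standing in for search_chunks:

    import asyncio

    async def fake_search(name: str, delay: float) -> list[str]:
        await asyncio.sleep(delay)
        return [f"{name}-1", "shared"]

    async def main() -> None:
        all_ids = await asyncio.gather(
            asyncio.wait_for(fake_search("text", 0.01), timeout=2),
            asyncio.wait_for(fake_search("multimodal", 0.01), timeout=2),
        )
        # Set comprehension dedups IDs found by both searches
        print({i for ids in all_ids for i in ids})  # {'text-1', 'shared', 'multimodal-1'}

    asyncio.run(main())

Note that unlike the deleted with_timeout helper, asyncio.wait_for here lets a TimeoutError propagate out of gather rather than degrading to an empty result.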

View File

@@ -4,36 +4,70 @@ Search endpoints for the knowledge base API.
import asyncio
import logging
from collections import defaultdict
from typing import Optional
from sqlalchemy.orm import load_only
from memory.common import extract, settings
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.api.search.embeddings import search_embeddings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25
from memory.api.search.bm25 import search_bm25_chunks
from memory.api.search.utils import (
SearchFilters,
SearchResult,
group_chunks,
with_timeout,
)
from memory.api.search.types import SearchFilters, SearchResult
logger = logging.getLogger(__name__)
async def search_chunks(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[Chunk]:
funcs = [search_chunks_embeddings]
if settings.ENABLE_BM25_SEARCH:
funcs.append(search_bm25_chunks)
all_ids = await asyncio.gather(
*[func(data, modalities, limit, filters, timeout) for func in funcs]
)
all_ids = {id for ids in all_ids for id in ids}
with make_session() as db:
chunks = (
db.query(Chunk)
.options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
.filter(Chunk.id.in_(all_ids))
.all()
)
db.expunge_all()
return chunks
async def search_sources(
chunks: list[Chunk], previews: Optional[bool] = False
) -> list[SearchResult]:
by_source = defaultdict(list)
for chunk in chunks:
by_source[chunk.source_id].append(chunk)
with make_session() as db:
sources = db.query(SourceItem).filter(SourceItem.id.in_(by_source.keys())).all()
return [
SearchResult.from_source_item(source, by_source[source.id], previews)
for source in sources
]
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_text_score: float = 0.4,
min_multimodal_score: float = 0.25,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[SearchResult]:
@@ -50,56 +84,11 @@ async def search(
- List of search results sorted by score
"""
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
searches = []
if settings.ENABLE_EMBEDDING_SEARCH:
searches = [
with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
limit,
min_text_score,
filters,
multimodal=False,
),
timeout,
),
with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
),
timeout,
),
]
if settings.ENABLE_BM25_SEARCH:
searches.append(
with_timeout(
search_bm25(
" ".join(
[c for chunk in data for c in chunk.data if isinstance(c, str)]
),
modalities,
limit=limit,
filters=filters,
),
timeout,
)
)
search_results = await asyncio.gather(*searches, return_exceptions=False)
all_results = []
for results in search_results:
if len(all_results) >= limit:
break
all_results.extend(results)
results = group_chunks(all_results, previews or False)
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
chunks = await search_chunks(
data,
allowed_modalities,
limit,
filters,
timeout,
)
return await search_sources(chunks, previews)
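
Search is now a two-stage pipeline: every enabled backend returns chunk IDs, a single query hydrates them, and search_sources groups the chunks by source. A hedged end-to-end sketch (the "text" modality is illustrative):

    import asyncio

    from memory.api.search import search
    from memory.common import extract

    async def demo() -> None:
        data = [extract.DataChunk(data=["what did I read about bm25?"])]
        results = await search(data, modalities={"text"}, limit=5)
        for r in results:  # one SearchResult per source, chunk contents already elided
            print(r.id, r.filename, len(r.chunks))

    asyncio.run(demo())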

View File

@@ -0,0 +1,67 @@
from datetime import datetime
import logging
from typing import Optional, TypedDict, NotRequired, cast
from memory.common.db.models.source_item import SourceItem
from pydantic import BaseModel
from memory.common.db.models import Chunk
from memory.common import settings
logger = logging.getLogger(__name__)
class SearchResponse(BaseModel):
collection: str
results: list[dict]
def elide_content(content: str, max_length: int = 100) -> str:
if content and len(content) > max_length:
return content[:max_length] + "..."
return content
class SearchResult(BaseModel):
id: int
chunks: list[str]
size: int | None = None
mime_type: str | None = None
content: Optional[str | dict] = None
filename: Optional[str] = None
tags: list[str] | None = None
metadata: dict | None = None
created_at: datetime | None = None
@classmethod
def from_source_item(
cls, source: SourceItem, chunks: list[Chunk], previews: Optional[bool] = False
) -> "SearchResult":
metadata = source.display_contents or {}
metadata.pop("content", None)
chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
return cls(
id=cast(int, source.id),
size=cast(int, source.size),
mime_type=cast(str, source.mime_type),
chunks=[elide_content(str(chunk.content), chunk_size) for chunk in chunks],
content=elide_content(
cast(str, source.content),
settings.MAX_PREVIEW_LENGTH
if previews
else settings.MAX_NON_PREVIEW_LENGTH,
),
filename=cast(str, source.filename),
tags=cast(list[str], source.tags),
metadata=metadata,
created_at=cast(datetime | None, source.inserted_at),
)
class SearchFilters(TypedDict):
min_size: NotRequired[int]
max_size: NotRequired[int]
min_confidences: NotRequired[dict[str, float]]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
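
SearchFilters is a TypedDict, so filters stay plain dicts with optional keys; for example (values illustrative):

    from memory.api.search.types import SearchFilters

    filters: SearchFilters = {
        "min_size": 1024,               # only sources of at least 1 KiB
        "observation_types": ["note"],
    }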

View File

@@ -1,140 +0,0 @@
import asyncio
import traceback
from datetime import datetime
import logging
from collections import defaultdict
from typing import Optional, TypedDict, NotRequired
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
class AnnotatedChunk(BaseModel):
id: str
score: float
metadata: dict
preview: Optional[str | None] = None
search_method: str | None = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content_length: int
contents: dict | str | None
created_at: datetime | None
@staticmethod
def from_chunk(chunk: Chunk) -> "SourceData":
source = chunk.source
display_contents = source.display_contents or {}
return SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content_length=len(source.content) if source.content else 0,
contents={k: v for k, v in display_contents.items() if v is not None},
created_at=source.inserted_at,
)
class SearchResponse(BaseModel):
collection: str
results: list[dict]
class SearchResult(BaseModel):
id: int
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str | dict] = None
filename: Optional[str] = None
tags: list[str] | None = None
metadata: dict | None = None
created_at: datetime | None = None
class SearchFilters(TypedDict):
min_size: NotRequired[int]
max_size: NotRequired[int]
min_confidences: NotRequired[dict[str, float]]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
traceback.print_exc()
logger.error(f"Search failed: {e}")
return []
def group_chunks(
chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source.id].append(chunk)
source_lookup[source.id] = source
def get_content(text: str | dict | None) -> str | dict | None:
if isinstance(text, str) and len(text) > settings.MAX_PREVIEW_LENGTH:
return None
return text
def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
contents = source.contents or {}
tags = []
if isinstance(contents, dict):
tags = contents.pop("tags", [])
content = contents.pop("content", None)
else:
content = contents
contents = {}
return SearchResult(
id=source.id,
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=get_content(content),
tags=tags,
metadata=contents,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
created_at=source.created_at,
)
return [
make_result(source, chunks)
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]

View File

@@ -347,7 +347,7 @@ class SourceItem(Base):
collection_name=modality,
embedding_model=collections.collection_model(modality, text, images),
item_metadata=extract.merge_metadata(
self.as_payload(), data.metadata, metadata
cast(dict[str, Any], self.as_payload()), data.metadata, metadata
),
)
return chunk
@@ -368,7 +368,7 @@ class SourceItem(Base):
return [cls.__tablename__]
@property
def display_contents(self) -> str | dict | None:
def display_contents(self) -> dict | None:
payload = self.as_payload()
payload.pop("source_id", None) # type: ignore
return {

View File

@@ -15,7 +15,6 @@ from sqlalchemy import (
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from datetime import datetime
def hash_password(password: str) -> str:

View File

@@ -135,7 +135,7 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 8))
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
# API settings
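
Assuming DEFAULT_CHUNK_TOKENS keeps its repo default (set earlier in settings.py and not shown here), this change doubles the default preview cap; with an illustrative value of 512 tokens:

    DEFAULT_CHUNK_TOKENS = 512                        # illustrative only
    MAX_PREVIEW_LENGTH = DEFAULT_CHUNK_TOKENS * 16    # 8192 chars, up from 4096 (x8)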

View File

@@ -547,3 +547,141 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
# Verify both the MailMessage and SourceItem records are deleted
assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
@pytest.mark.parametrize(
"content,image_paths,expected_chunks",
[
("", [], 0), # Empty content returns empty list
(" \n ", [], 0), # Whitespace-only content returns empty list
("Short content", [], 1), # Short content returns just full_text chunk
("A" * 10, [], 1), # Very short content returns just full_text chunk
],
)
def test_chunk_mixed_basic_cases(tmp_path, content, image_paths, expected_chunks):
"""Test chunk_mixed function with basic cases"""
from memory.common.db.models.source_item import chunk_mixed
# Create test images if needed
actual_image_paths = []
for i, _ in enumerate(image_paths):
image_file = tmp_path / f"test{i}.png"
img = Image.new("RGB", (1, 1), color="red")
img.save(image_file)
actual_image_paths.append(image_file.name)
# Mock settings.FILE_STORAGE_DIR to point to tmp_path
with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
result = chunk_mixed(content, actual_image_paths)
assert len(result) == expected_chunks
def test_chunk_mixed_with_images(tmp_path):
"""Test chunk_mixed function with images"""
from memory.common.db.models.source_item import chunk_mixed
# Create test images
image1 = tmp_path / "image1.png"
image2 = tmp_path / "image2.jpg"
Image.new("RGB", (1, 1), color="red").save(image1)
Image.new("RGB", (1, 1), color="blue").save(image2)
content = "This content mentions image1.png and image2.jpg"
image_paths = [image1.name, image2.name]
with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
result = chunk_mixed(content, image_paths)
assert len(result) >= 1
# First chunk should contain the full text and images
assert content.strip() in result[0].data
assert len([d for d in result[0].data if isinstance(d, Image.Image)]) == 2
def test_chunk_mixed_long_content(tmp_path):
"""Test chunk_mixed function with long content that gets chunked"""
from memory.common.db.models.source_item import chunk_mixed
# Create long content
long_content = "Lorem ipsum dolor sit amet, " * 50 # About 150 words
# Mock the chunker functions to force chunking behavior
with (
patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10),
patch.object(chunker, "approx_token_count", return_value=100),
): # Force it to be > 2 * 10
result = chunk_mixed(long_content, [])
# Should have multiple chunks: full_text + chunked pieces + summary
assert len(result) > 1
# First chunk should be full text
assert long_content.strip() in result[0].data
# Last chunk should be summary
# (we can't easily test the exact summary without mocking summarizer)
assert result[-1].data # Should have some data
@pytest.mark.parametrize(
"sha256_values,expected_committed",
[
([b"unique1", b"unique2", b"unique3"], 3), # All unique
([b"duplicate", b"duplicate", b"unique"], 2), # One duplicate pair
([b"same", b"same", b"same"], 1), # All duplicates
([b"dup1", b"dup1", b"dup2", b"dup2"], 2), # Two duplicate pairs
],
)
def test_handle_duplicate_sha256_behavior(
db_session: Session, sha256_values, expected_committed
):
"""Test that handle_duplicate_sha256 event listener prevents duplicate sha256 values"""
# Create SourceItems with the given sha256 values
items = []
for i, sha256 in enumerate(sha256_values):
item = SourceItem(sha256=sha256, content=f"test content {i}", modality="text")
items.append(item)
db_session.add(item)
# Commit should trigger the event listener
db_session.commit()
# Query how many items were actually committed
committed_count = db_session.query(SourceItem).count()
assert committed_count == expected_committed
# Verify all sha256 values in database are unique
sha256_in_db = [row[0] for row in db_session.query(SourceItem.sha256).all()]
assert len(sha256_in_db) == len(set(sha256_in_db)) # All unique
def test_handle_duplicate_sha256_with_existing_data(db_session: Session):
"""Test duplicate handling when items already exist in database"""
# Add initial items
existing_item = SourceItem(sha256=b"existing", content="original", modality="text")
db_session.add(existing_item)
db_session.commit()
# Try to add new items with same and different sha256
new_items = [
SourceItem(
sha256=b"existing", content="duplicate", modality="text"
), # Should be rejected
SourceItem(
sha256=b"new_unique", content="new content", modality="text"
), # Should be kept
]
for item in new_items:
db_session.add(item)
db_session.commit()
# Should have 2 items total (original + new unique)
assert db_session.query(SourceItem).count() == 2
# Original content should be preserved
existing_in_db = db_session.query(SourceItem).filter_by(sha256=b"existing").first()
assert existing_in_db is not None
assert str(existing_in_db.content) == "original" # Original should be preserved
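
These tests pin down first-wins semantics: the earliest pending row (or an already-committed one) survives and later duplicates are silently dropped. The listener itself is not part of this diff; a hypothetical sketch of the shape being exercised, using SQLAlchemy's before_flush hook (the real handle_duplicate_sha256 may differ):

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(Session, "before_flush")
    def handle_duplicate_sha256(session, flush_context, instances):
        # Hypothetical: expunge pending rows whose sha256 is already taken,
        # keeping the first occurrence and any committed row intact.
        seen: set[bytes] = set()
        for obj in list(session.new):
            sha = getattr(obj, "sha256", None)
            if sha is None:
                continue
            with session.no_autoflush:
                in_db = session.query(type(obj)).filter_by(sha256=sha).first() is not None
            if sha in seen or in_db:
                session.expunge(obj)
            else:
                seen.add(sha)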

View File

@@ -0,0 +1,104 @@
import pytest
from memory.common.db.models.users import hash_password, verify_password
@pytest.mark.parametrize(
"password",
[
"simple_password",
"complex_P@ssw0rd!",
"very_long_password_with_many_characters_1234567890",
"",
"unicode_password_тест_😀",
"password with spaces",
],
)
def test_hash_password_format(password):
"""Test that hash_password returns correctly formatted hash"""
result = hash_password(password)
# Should be in format "salt:hash"
assert ":" in result
parts = result.split(":", 1)
assert len(parts) == 2
salt, hash_value = parts
# Salt should be 32 hex characters (16 bytes * 2)
assert len(salt) == 32
assert all(c in "0123456789abcdef" for c in salt)
# Hash should be 64 hex characters (SHA-256 = 32 bytes * 2)
assert len(hash_value) == 64
assert all(c in "0123456789abcdef" for c in hash_value)
def test_hash_password_uniqueness():
"""Test that same password generates different hashes due to random salt"""
password = "test_password"
hash1 = hash_password(password)
hash2 = hash_password(password)
# Different salts should produce different hashes
assert hash1 != hash2
# But both should verify correctly
assert verify_password(password, hash1)
assert verify_password(password, hash2)
@pytest.mark.parametrize(
"password,expected",
[
("correct_password", True),
("wrong_password", False),
("", False),
("CORRECT_PASSWORD", False), # Case sensitive
],
)
def test_verify_password_correctness(password, expected):
"""Test password verification with correct and incorrect passwords"""
correct_password = "correct_password"
password_hash = hash_password(correct_password)
result = verify_password(password, password_hash)
assert result == expected
@pytest.mark.parametrize(
"malformed_hash",
[
"invalid_format",
"no_colon_here",
":empty_salt",
"salt:", # Empty hash
"",
"too:many:colons:here",
"salt:invalid_hex_zzz",
"salt:too_short_hash",
],
)
def test_verify_password_malformed_hash(malformed_hash):
"""Test that verify_password handles malformed hashes gracefully"""
result = verify_password("any_password", malformed_hash)
assert result is False
@pytest.mark.parametrize(
"test_password",
[
"simple",
"complex_P@ssw0rd!123",
"",
"unicode_тест_😀",
"password with spaces and symbols !@#$%^&*()",
],
)
def test_hash_verify_roundtrip(test_password):
"""Test that hash and verify work correctly together"""
password_hash = hash_password(test_password)
# Correct password should verify
assert verify_password(test_password, password_hash)
# Wrong password should not verify
assert not verify_password(test_password + "_wrong", password_hash)
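
Together these tests fix the storage format: "salt:hash" with a 16-byte hex salt and a SHA-256 hex digest. A minimal sketch consistent with them (the salt+password concatenation order is an assumption; the real code lives in memory.common.db.models.users):

    import hashlib
    import secrets

    def hash_password(password: str) -> str:
        salt = secrets.token_hex(16)  # 32 hex chars
        digest = hashlib.sha256((salt + password).encode("utf-8")).hexdigest()
        return f"{salt}:{digest}"

    def verify_password(password: str, stored: str) -> bool:
        try:
            salt, expected = stored.split(":", 1)
        except ValueError:  # malformed hash without a colon
            return False
        digest = hashlib.sha256((salt + password).encode("utf-8")).hexdigest()
        return secrets.compare_digest(digest, expected)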

View File

@@ -1,23 +1,30 @@
from unittest.mock import Mock
import pytest
from typing import cast
from PIL import Image
from memory.common import collections
from memory.common import collections, settings
from memory.common.embedding import (
as_string,
embed_chunks,
embed_mixed,
embed_text,
break_chunk,
embed_by_model,
)
from memory.common.extract import DataChunk
from memory.common.extract import DataChunk, MulitmodalChunk
from memory.common.db.models import Chunk, SourceItem
@pytest.fixture
def mock_embed(mock_voyage_client):
vectors = ([i] for i in range(1000))
def embed(texts, model, input_type):
def embed_func(texts, model, input_type):
return Mock(embeddings=[next(vectors) for _ in texts])
mock_voyage_client.embed = embed
mock_voyage_client.multimodal_embed = embed
mock_voyage_client.embed = Mock(side_effect=embed_func)
mock_voyage_client.multimodal_embed = Mock(side_effect=embed_func)
return mock_voyage_client
@@ -52,3 +59,182 @@ def test_embed_text(mock_embed):
def test_embed_mixed(mock_embed):
items = [DataChunk(data=["text"])]
assert embed_mixed(items) == [[0]]
@pytest.mark.parametrize(
"input_data, expected_output",
[
("hello world", "hello world"),
(" hello world \n", "hello world"),
(
cast(list[MulitmodalChunk], ["first chunk", "second chunk", "third chunk"]),
"first chunk\nsecond chunk\nthird chunk",
),
(cast(list[MulitmodalChunk], []), ""),
(
cast(list[MulitmodalChunk], ["", "valid text", " ", "another text"]),
"valid text\n\nanother text",
),
],
)
def test_as_string_basic_cases(input_data, expected_output):
assert as_string(input_data) == expected_output
def test_as_string_with_nested_lists():
# This tests the recursive nature of as_string - kept separate due to different input type
chunks = [["nested", "items"], "single item"]
result = as_string(chunks)
assert result == "nested\nitems\nsingle item"
def test_embed_chunks_with_text_model(mock_embed):
chunks = cast(list[list[MulitmodalChunk]], [["text1"], ["text2"]])
result = embed_chunks(chunks, model=settings.TEXT_EMBEDDING_MODEL)
assert result == [[0], [1]]
mock_embed.embed.assert_called_once_with(
["text1", "text2"],
model=settings.TEXT_EMBEDDING_MODEL,
input_type="document",
)
def test_embed_chunks_with_mixed_model(mock_embed):
chunks = cast(list[list[MulitmodalChunk]], [["text with image"], ["another chunk"]])
result = embed_chunks(chunks, model=settings.MIXED_EMBEDDING_MODEL)
assert result == [[0], [1]]
mock_embed.multimodal_embed.assert_called_once_with(
chunks, model=settings.MIXED_EMBEDDING_MODEL, input_type="document"
)
def test_embed_chunks_with_query_input_type(mock_embed):
chunks = cast(list[list[MulitmodalChunk]], [["query text"]])
result = embed_chunks(chunks, input_type="query")
assert result == [[0]]
mock_embed.embed.assert_called_once_with(
["query text"], model=settings.TEXT_EMBEDDING_MODEL, input_type="query"
)
def test_embed_chunks_empty_list(mock_embed):
result = embed_chunks([])
assert result == []
@pytest.mark.parametrize(
"data, chunk_size, expected_result",
[
(["short text"], 100, ["short text"]),
(["some text content"], 200, ["some text content"]),
([], 100, []),
],
)
def test_break_chunk_simple_cases(data, chunk_size, expected_result):
chunk = DataChunk(data=data)
result = break_chunk(chunk, chunk_size=chunk_size)
assert result == expected_result
def test_break_chunk_with_long_text():
# Create text that will exceed chunk size
long_text = "word " * 200 # Should be much longer than default chunk size
chunk = DataChunk(data=[long_text])
result = break_chunk(chunk, chunk_size=50)
# Should be broken into multiple chunks
assert len(result) > 1
assert all(isinstance(item, str) for item in result)
def test_break_chunk_with_mixed_data_types():
# Mock image object
mock_image = Mock(spec=Image.Image)
chunk = DataChunk(data=["text content", mock_image])
result = break_chunk(chunk, chunk_size=100)
# Should have text chunks plus the original chunk (since it's not a string)
assert len(result) >= 2
assert any(isinstance(item, str) for item in result)
# The original chunk should be preserved when it contains mixed data
assert chunk in result
def test_embed_by_model_with_matching_chunks(mock_embed):
# Create mock chunks with specific embedding model
chunk1 = Mock(spec=Chunk)
chunk1.embedding_model = "test-model"
chunk1.chunks = ["chunk1 content"]
chunk2 = Mock(spec=Chunk)
chunk2.embedding_model = "test-model"
chunk2.chunks = ["chunk2 content"]
chunks = cast(list[Chunk], [chunk1, chunk2])
result = embed_by_model(chunks, "test-model")
assert len(result) == 2
assert chunk1.vector == [0]
assert chunk2.vector == [1]
assert result == [chunk1, chunk2]
def test_embed_by_model_with_no_matching_chunks(mock_embed):
chunk1 = Mock(spec=Chunk)
chunk1.embedding_model = "different-model"
# Ensure the chunk doesn't have a vector initially
del chunk1.vector
chunks = cast(list[Chunk], [chunk1])
result = embed_by_model(chunks, "test-model")
assert result == []
assert not hasattr(chunk1, "vector")
def test_embed_by_model_with_mixed_models(mock_embed):
chunk1 = Mock(spec=Chunk)
chunk1.embedding_model = "test-model"
chunk1.chunks = ["chunk1 content"]
chunk2 = Mock(spec=Chunk)
chunk2.embedding_model = "other-model"
chunk2.chunks = ["chunk2 content"]
chunk3 = Mock(spec=Chunk)
chunk3.embedding_model = "test-model"
chunk3.chunks = ["chunk3 content"]
chunks = cast(list[Chunk], [chunk1, chunk2, chunk3])
result = embed_by_model(chunks, "test-model")
assert len(result) == 2
assert chunk1 in result
assert chunk3 in result
assert chunk2 not in result
assert chunk1.vector == [0]
assert chunk3.vector == [1]
def test_embed_by_model_with_empty_chunks(mock_embed):
result = embed_by_model([], "test-model")
assert result == []
def test_embed_by_model_calls_embed_chunks_correctly(mock_embed):
chunk1 = Mock(spec=Chunk)
chunk1.embedding_model = "test-model"
chunk1.chunks = ["content1"]
chunk2 = Mock(spec=Chunk)
chunk2.embedding_model = "test-model"
chunk2.chunks = ["content2"]
chunks = cast(list[Chunk], [chunk1, chunk2])
embed_by_model(chunks, "test-model")
# Verify embed_chunks was called with the right model
expected_chunks = [["content1"], ["content2"]]
mock_embed.embed.assert_called_once_with(
["content1", "content2"], model="test-model", input_type="document"
)