Mirror of https://github.com/mruwnik/memory.git, synced 2025-07-30 06:36:07 +02:00

Compare commits: 2 commits, 50d0eb97db ... 06eec621c1

Commits in this range: 06eec621c1, 01ccea2733
@@ -90,6 +90,7 @@ export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
     <div className="search-result-card">
         <h4>{title}</h4>
         <Tag tags={tags} />
+        <Metadata metadata={metadata} />
         <div className="image-container">
             {mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
         </div>

@@ -115,7 +116,7 @@ export const Metadata = ({ metadata }: { metadata: any }) => {
     return (
         <div className="metadata">
             <ul>
-                {Object.entries(metadata).map(([key, value]) => (
+                {Object.entries(metadata).filter(([key, value]) => ![null, undefined].includes(value as any)).map(([key, value]) => (
                     <MetadataItem key={key} item={key} value={value} />
                 ))}
             </ul>

@@ -154,19 +155,19 @@ export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
 }

 export const SearchResult = ({ result }: { result: SearchItem }) => {
-    if (result.mime_type.startsWith('image/')) {
+    if (result.mime_type?.startsWith('image/')) {
         return <ImageResult {...result} />
     }
-    if (result.mime_type.startsWith('text/markdown')) {
+    if (result.mime_type?.startsWith('text/markdown')) {
         return <MarkdownResult {...result} />
     }
-    if (result.mime_type.startsWith('text/')) {
+    if (result.mime_type?.startsWith('text/')) {
         return <TextResult {...result} />
     }
-    if (result.mime_type.startsWith('application/pdf')) {
+    if (result.mime_type?.startsWith('application/pdf')) {
         return <PDFResult {...result} />
     }
-    if (result.mime_type.startsWith('message/rfc822')) {
+    if (result.mime_type?.startsWith('message/rfc822')) {
         return <EmailResult {...result} />
     }
     console.log(result)
@@ -109,14 +109,12 @@ async def search_knowledge_base(
     search_filters = SearchFilters(**filters)
     search_filters["source_ids"] = filter_source_ids(modalities, search_filters)

-    upload_data = extract.extract_text(query)
+    upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
         previews=previews,
         modalities=modalities,
         limit=limit,
-        min_text_score=0.4,
-        min_multimodal_score=0.25,
         filters=search_filters,
     )

@@ -1,4 +1,4 @@
 from .search import search
-from .utils import SearchResult, SearchFilters
+from .types import SearchResult, SearchFilters

 __all__ = ["search", "SearchResult", "SearchFilters"]
@@ -2,13 +2,15 @@
 Search endpoints for the knowledge base API.
 """

+import asyncio
 from hashlib import sha256
 import logging

 import bm25s
 import Stemmer
-from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
+from memory.api.search.types import SearchFilters

+from memory.common import extract
 from memory.common.db.connection import make_session
 from memory.common.db.models import Chunk, ConfidenceScore

@@ -20,7 +22,7 @@ async def search_bm25(
     modalities: set[str],
     limit: int = 10,
     filters: SearchFilters = SearchFilters(),
-) -> list[tuple[SourceData, AnnotatedChunk]]:
+) -> list[str]:
     with make_session() as db:
         items_query = db.query(Chunk.id, Chunk.content).filter(
             Chunk.collection_name.in_(modalities),

@@ -65,21 +67,18 @@ async def search_bm25(
         item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
         for doc, score in zip(results[0], scores[0])
     }
+    return list(item_scores.keys())

-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
-        results = []
-        for chunk in chunks:
-            # Prefetch all needed source data while in session
-            source_data = SourceData.from_chunk(chunk)
-
-            annotated = AnnotatedChunk(
-                id=str(chunk.id),
-                score=item_scores[chunk.id],
-                metadata=chunk.source.as_payload(),
-                preview=None,
-                search_method="bm25",
-            )
-            results.append((source_data, annotated))
-
-    return results
+
+async def search_bm25_chunks(
+    data: list[extract.DataChunk],
+    modalities: set[str] = set(),
+    limit: int = 10,
+    filters: SearchFilters = SearchFilters(),
+    timeout: int = 2,
+) -> list[str]:
+    query = " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)])
+    return await asyncio.wait_for(
+        search_bm25(query, modalities, limit, filters),
+        timeout,
+    )
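The new `search_bm25_chunks` wrapper replaces the old `with_timeout` helper with a direct `asyncio.wait_for` call. A minimal, self-contained sketch of that pattern (the `slow_search` coroutine is a hypothetical stand-in for `search_bm25`):

```python
import asyncio


async def slow_search(query: str) -> list[str]:
    # Stand-in for search_bm25: pretend the index lookup takes a while.
    await asyncio.sleep(0.1)
    return [f"chunk-id-for:{query}"]


async def search_with_timeout(query: str, timeout: float = 2.0) -> list[str]:
    # asyncio.wait_for cancels the inner coroutine and raises TimeoutError
    # if it does not finish within `timeout` seconds.
    return await asyncio.wait_for(slow_search(query), timeout)


if __name__ == "__main__":
    print(asyncio.run(search_with_timeout("hello world")))
```

Note one behavioral difference visible in this diff: the removed `with_timeout` helper caught `TimeoutError` and returned an empty list, while a bare `asyncio.wait_for` lets the timeout propagate to the caller.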
@@ -1,65 +1,20 @@
-import base64
-import io
 import logging
 import asyncio
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, cast

 import qdrant_client
-from PIL import Image
 from qdrant_client.http import models as qdrant_models

-from memory.common import embedding, extract, qdrant, settings
-from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk
-from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
+from memory.common import embedding, extract, qdrant
+from memory.common.collections import (
+    MULTIMODAL_COLLECTIONS,
+    TEXT_COLLECTIONS,
+)
+from memory.api.search.types import SearchFilters

 logger = logging.getLogger(__name__)


-def annotated_chunk(
-    chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceData, AnnotatedChunk]:
-    def serialize_item(item: bytes | str | Image.Image) -> str | None:
-        if not previews and not isinstance(item, str):
-            return None
-        if (
-            not previews
-            and isinstance(item, str)
-            and len(item) > settings.MAX_NON_PREVIEW_LENGTH
-        ):
-            return item[: settings.MAX_NON_PREVIEW_LENGTH] + "..."
-        elif isinstance(item, str):
-            if len(item) > settings.MAX_PREVIEW_LENGTH:
-                return None
-            return item
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-
-    # Prefetch all needed source data while in session
-    return SourceData.from_chunk(chunk), AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-        search_method="embeddings",
-    )
-
-
 async def query_chunks(
     client: qdrant_client.QdrantClient,
     upload_data: list[extract.DataChunk],

@@ -178,15 +133,14 @@ def merge_filters(
     return filters


-async def search_embeddings(
+async def search_chunks(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
     limit: int = 10,
     min_score: float = 0.3,
     filters: SearchFilters = {},
     multimodal: bool = False,
-) -> list[tuple[SourceData, AnnotatedChunk]]:
+) -> list[str]:
     """
     Search across knowledge base using text query and optional files.

@@ -218,9 +172,38 @@ async def search_embeddings(
     found_chunks = {
         str(r.id): r for results in search_results.values() for r in results
     }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        return [
-            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-            for chunk in chunks
-        ]
+    return list(found_chunks.keys())
+
+
+async def search_chunks_embeddings(
+    data: list[extract.DataChunk],
+    modalities: set[str] = set(),
+    limit: int = 10,
+    filters: SearchFilters = SearchFilters(),
+    timeout: int = 2,
+) -> list[str]:
+    all_ids = await asyncio.gather(
+        asyncio.wait_for(
+            search_chunks(
+                data,
+                modalities & TEXT_COLLECTIONS,
+                limit,
+                0.4,
+                filters,
+                False,
+            ),
+            timeout,
+        ),
+        asyncio.wait_for(
+            search_chunks(
+                data,
+                modalities & MULTIMODAL_COLLECTIONS,
+                limit,
+                0.25,
+                filters,
+                True,
+            ),
+            timeout,
+        ),
+    )
+    return list({id for ids in all_ids for id in ids})
@@ -4,36 +4,70 @@ Search endpoints for the knowledge base API.

 import asyncio
 import logging
+from collections import defaultdict
 from typing import Optional

+from sqlalchemy.orm import load_only
 from memory.common import extract, settings
-from memory.common.collections import (
-    ALL_COLLECTIONS,
-    MULTIMODAL_COLLECTIONS,
-    TEXT_COLLECTIONS,
-)
-from memory.api.search.embeddings import search_embeddings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+from memory.common.collections import ALL_COLLECTIONS
+from memory.api.search.embeddings import search_chunks_embeddings

 if settings.ENABLE_BM25_SEARCH:
-    from memory.api.search.bm25 import search_bm25
+    from memory.api.search.bm25 import search_bm25_chunks

-from memory.api.search.utils import (
-    SearchFilters,
-    SearchResult,
-    group_chunks,
-    with_timeout,
-)
+from memory.api.search.types import SearchFilters, SearchResult

 logger = logging.getLogger(__name__)


+async def search_chunks(
+    data: list[extract.DataChunk],
+    modalities: set[str] = set(),
+    limit: int = 10,
+    filters: SearchFilters = {},
+    timeout: int = 2,
+) -> list[Chunk]:
+    funcs = [search_chunks_embeddings]
+    if settings.ENABLE_BM25_SEARCH:
+        funcs.append(search_bm25_chunks)
+
+    all_ids = await asyncio.gather(
+        *[func(data, modalities, limit, filters, timeout) for func in funcs]
+    )
+    all_ids = {id for ids in all_ids for id in ids}
+
+    with make_session() as db:
+        chunks = (
+            db.query(Chunk)
+            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .filter(Chunk.id.in_(all_ids))
+            .all()
+        )
+        db.expunge_all()
+        return chunks
+
+
+async def search_sources(
+    chunks: list[Chunk], previews: Optional[bool] = False
+) -> list[SearchResult]:
+    by_source = defaultdict(list)
+    for chunk in chunks:
+        by_source[chunk.source_id].append(chunk)
+
+    with make_session() as db:
+        sources = db.query(SourceItem).filter(SourceItem.id.in_(by_source.keys())).all()
+        return [
+            SearchResult.from_source_item(source, by_source[source.id], previews)
+            for source in sources
+        ]
+
+
 async def search(
     data: list[extract.DataChunk],
     previews: Optional[bool] = False,
     modalities: set[str] = set(),
     limit: int = 10,
-    min_text_score: float = 0.4,
-    min_multimodal_score: float = 0.25,
     filters: SearchFilters = {},
     timeout: int = 2,
 ) -> list[SearchResult]:

@@ -50,56 +84,11 @@ async def search(
     - List of search results sorted by score
     """
     allowed_modalities = modalities & ALL_COLLECTIONS.keys()
-
-    searches = []
-    if settings.ENABLE_EMBEDDING_SEARCH:
-        searches = [
-            with_timeout(
-                search_embeddings(
-                    data,
-                    previews,
-                    allowed_modalities & TEXT_COLLECTIONS,
-                    limit,
-                    min_text_score,
-                    filters,
-                    multimodal=False,
-                ),
-                timeout,
-            ),
-            with_timeout(
-                search_embeddings(
-                    data,
-                    previews,
-                    allowed_modalities & MULTIMODAL_COLLECTIONS,
-                    limit,
-                    min_multimodal_score,
-                    filters,
-                    multimodal=True,
-                ),
-                timeout,
-            ),
-        ]
-    if settings.ENABLE_BM25_SEARCH:
-        searches.append(
-            with_timeout(
-                search_bm25(
-                    " ".join(
-                        [c for chunk in data for c in chunk.data if isinstance(c, str)]
-                    ),
-                    modalities,
-                    limit=limit,
-                    filters=filters,
-                ),
-                timeout,
-            )
-        )
-
-    search_results = await asyncio.gather(*searches, return_exceptions=False)
-    all_results = []
-    for results in search_results:
-        if len(all_results) >= limit:
-            break
-        all_results.extend(results)
-
-    results = group_chunks(all_results, previews or False)
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
+    chunks = await search_chunks(
+        data,
+        allowed_modalities,
+        limit,
+        filters,
+        timeout,
+    )
+    return await search_sources(chunks, previews)
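After this refactor, `search()` is a two-stage pipeline: `search_chunks()` fans out to the enabled retrieval back ends and loads the matching `Chunk` rows, and `search_sources()` groups those chunks by `source_id` and builds `SearchResult` objects. A minimal sketch of how a caller might drive it, based only on the signatures shown above; the query text and the `"text"` modality name are illustrative placeholders, and running it requires the project's database and vector store to be configured:

```python
import asyncio

from memory.common import extract
from memory.api.search import search  # re-exported by the package __init__


async def run_query() -> None:
    # extract_text() turns the raw query into DataChunk objects, as in
    # search_knowledge_base() earlier in this diff.
    data = extract.extract_text("my search terms", skip_summary=True)
    results = await search(
        data,
        previews=True,
        modalities={"text"},  # assumed modality name; must exist in ALL_COLLECTIONS
        limit=10,
    )
    for result in results:
        # SearchResult fields come from the new types module below.
        print(result.id, result.mime_type, len(result.chunks))


# asyncio.run(run_query())  # needs a configured DB / Qdrant instance
```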
src/memory/api/search/types.py (new file, 67 lines)

@@ -0,0 +1,67 @@
+from datetime import datetime
+import logging
+from typing import Optional, TypedDict, NotRequired, cast
+
+from memory.common.db.models.source_item import SourceItem
+from pydantic import BaseModel
+
+from memory.common.db.models import Chunk
+from memory.common import settings
+
+logger = logging.getLogger(__name__)
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: list[dict]
+
+
+def elide_content(content: str, max_length: int = 100) -> str:
+    if content and len(content) > max_length:
+        return content[:max_length] + "..."
+    return content
+
+
+class SearchResult(BaseModel):
+    id: int
+    chunks: list[str]
+    size: int | None = None
+    mime_type: str | None = None
+    content: Optional[str | dict] = None
+    filename: Optional[str] = None
+    tags: list[str] | None = None
+    metadata: dict | None = None
+    created_at: datetime | None = None
+
+    @classmethod
+    def from_source_item(
+        cls, source: SourceItem, chunks: list[Chunk], previews: Optional[bool] = False
+    ) -> "SearchResult":
+        metadata = source.display_contents or {}
+        metadata.pop("content", None)
+        chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
+
+        return cls(
+            id=cast(int, source.id),
+            size=cast(int, source.size),
+            mime_type=cast(str, source.mime_type),
+            chunks=[elide_content(str(chunk.content), chunk_size) for chunk in chunks],
+            content=elide_content(
+                cast(str, source.content),
+                settings.MAX_PREVIEW_LENGTH
+                if previews
+                else settings.MAX_NON_PREVIEW_LENGTH,
+            ),
+            filename=cast(str, source.filename),
+            tags=cast(list[str], source.tags),
+            metadata=metadata,
+            created_at=cast(datetime | None, source.inserted_at),
+        )
+
+
+class SearchFilters(TypedDict):
+    min_size: NotRequired[int]
+    max_size: NotRequired[int]
+    min_confidences: NotRequired[dict[str, float]]
+    observation_types: NotRequired[list[str] | None]
+    source_ids: NotRequired[list[int] | None]
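`elide_content` is the piece that keeps previews and chunk text bounded. Copied from the new module above, it simply truncates and appends an ellipsis; a quick standalone check of its behavior:

```python
def elide_content(content: str, max_length: int = 100) -> str:
    # Same logic as in memory/api/search/types.py: truncate long strings,
    # pass short (or empty) ones through unchanged.
    if content and len(content) > max_length:
        return content[:max_length] + "..."
    return content


print(elide_content("short text", max_length=20))  # -> "short text"
print(elide_content("x" * 50, max_length=20))      # -> 20 x's followed by "..."
print(elide_content("", max_length=20))            # -> ""
```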
@@ -1,140 +0,0 @@
-import asyncio
-import traceback
-from datetime import datetime
-import logging
-from collections import defaultdict
-from typing import Optional, TypedDict, NotRequired
-
-from pydantic import BaseModel
-
-from memory.common import settings
-from memory.common.db.models import Chunk
-
-logger = logging.getLogger(__name__)
-
-
-class AnnotatedChunk(BaseModel):
-    id: str
-    score: float
-    metadata: dict
-    preview: Optional[str | None] = None
-    search_method: str | None = None
-
-
-class SourceData(BaseModel):
-    """Holds source item data to avoid SQLAlchemy session issues"""
-
-    id: int
-    size: int | None
-    mime_type: str | None
-    filename: str | None
-    content_length: int
-    contents: dict | str | None
-    created_at: datetime | None
-
-    @staticmethod
-    def from_chunk(chunk: Chunk) -> "SourceData":
-        source = chunk.source
-        display_contents = source.display_contents or {}
-        return SourceData(
-            id=source.id,
-            size=source.size,
-            mime_type=source.mime_type,
-            filename=source.filename,
-            content_length=len(source.content) if source.content else 0,
-            contents={k: v for k, v in display_contents.items() if v is not None},
-            created_at=source.inserted_at,
-        )
-
-
-class SearchResponse(BaseModel):
-    collection: str
-    results: list[dict]
-
-
-class SearchResult(BaseModel):
-    id: int
-    size: int
-    mime_type: str
-    chunks: list[AnnotatedChunk]
-    content: Optional[str | dict] = None
-    filename: Optional[str] = None
-    tags: list[str] | None = None
-    metadata: dict | None = None
-    created_at: datetime | None = None
-
-
-class SearchFilters(TypedDict):
-    min_size: NotRequired[int]
-    max_size: NotRequired[int]
-    min_confidences: NotRequired[dict[str, float]]
-    observation_types: NotRequired[list[str] | None]
-    source_ids: NotRequired[list[int] | None]
-
-
-async def with_timeout(
-    call, timeout: int = 2
-) -> list[tuple[SourceData, AnnotatedChunk]]:
-    """
-    Run a function with a timeout.
-
-    Args:
-        call: The function to run
-        timeout: The timeout in seconds
-    """
-    try:
-        return await asyncio.wait_for(call, timeout=timeout)
-    except TimeoutError:
-        logger.warning(f"Search timed out after {timeout}s")
-        return []
-    except Exception as e:
-        traceback.print_exc()
-        logger.error(f"Search failed: {e}")
-        return []
-
-
-def group_chunks(
-    chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
-) -> list[SearchResult]:
-    items = defaultdict(list)
-    source_lookup = {}
-
-    for source, chunk in chunks:
-        items[source.id].append(chunk)
-        source_lookup[source.id] = source
-
-    def get_content(text: str | dict | None) -> str | dict | None:
-        if isinstance(text, str) and len(text) > settings.MAX_PREVIEW_LENGTH:
-            return None
-        return text
-
-    def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
-        contents = source.contents or {}
-        tags = []
-        if isinstance(contents, dict):
-            tags = contents.pop("tags", [])
-            content = contents.pop("content", None)
-        else:
-            content = contents
-            contents = {}
-
-        return SearchResult(
-            id=source.id,
-            size=source.size or source.content_length,
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=get_content(content),
-            tags=tags,
-            metadata=contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-            created_at=source.created_at,
-        )
-
-    return [
-        make_result(source, chunks)
-        for source_id, chunks in items.items()
-        for source in [source_lookup[source_id]]
-    ]
@@ -347,7 +347,7 @@ class SourceItem(Base):
             collection_name=modality,
             embedding_model=collections.collection_model(modality, text, images),
             item_metadata=extract.merge_metadata(
-                self.as_payload(), data.metadata, metadata
+                cast(dict[str, Any], self.as_payload()), data.metadata, metadata
             ),
         )
         return chunk

@@ -368,7 +368,7 @@ class SourceItem(Base):
         return [cls.__tablename__]

     @property
-    def display_contents(self) -> str | dict | None:
+    def display_contents(self) -> dict | None:
         payload = self.as_payload()
         payload.pop("source_id", None)  # type: ignore
         return {
@@ -15,7 +15,6 @@ from sqlalchemy import (
 )
 from sqlalchemy.sql import func
 from sqlalchemy.orm import relationship
-from datetime import datetime


 def hash_password(password: str) -> str:
@@ -135,7 +135,7 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
-MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 8))
+MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))

 # API settings
@@ -547,3 +547,141 @@ def test_subclass_deletion_cascades_from_source_item(db_session: Session):
     # Verify both the MailMessage and SourceItem records are deleted
     assert db_session.query(MailMessage).filter_by(id=mail_message_id).first() is None
     assert db_session.query(SourceItem).filter_by(id=source_item_id).first() is None
+
+
+@pytest.mark.parametrize(
+    "content,image_paths,expected_chunks",
+    [
+        ("", [], 0),  # Empty content returns empty list
+        (" \n ", [], 0),  # Whitespace-only content returns empty list
+        ("Short content", [], 1),  # Short content returns just full_text chunk
+        ("A" * 10, [], 1),  # Very short content returns just full_text chunk
+    ],
+)
+def test_chunk_mixed_basic_cases(tmp_path, content, image_paths, expected_chunks):
+    """Test chunk_mixed function with basic cases"""
+    from memory.common.db.models.source_item import chunk_mixed
+
+    # Create test images if needed
+    actual_image_paths = []
+    for i, _ in enumerate(image_paths):
+        image_file = tmp_path / f"test{i}.png"
+        img = Image.new("RGB", (1, 1), color="red")
+        img.save(image_file)
+        actual_image_paths.append(image_file.name)
+
+    # Mock settings.FILE_STORAGE_DIR to point to tmp_path
+    with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
+        result = chunk_mixed(content, actual_image_paths)
+
+    assert len(result) == expected_chunks
+
+
+def test_chunk_mixed_with_images(tmp_path):
+    """Test chunk_mixed function with images"""
+    from memory.common.db.models.source_item import chunk_mixed
+
+    # Create test images
+    image1 = tmp_path / "image1.png"
+    image2 = tmp_path / "image2.jpg"
+    Image.new("RGB", (1, 1), color="red").save(image1)
+    Image.new("RGB", (1, 1), color="blue").save(image2)
+
+    content = "This content mentions image1.png and image2.jpg"
+    image_paths = [image1.name, image2.name]
+
+    with patch.object(settings, "FILE_STORAGE_DIR", tmp_path):
+        result = chunk_mixed(content, image_paths)
+
+    assert len(result) >= 1
+    # First chunk should contain the full text and images
+    assert content.strip() in result[0].data
+    assert len([d for d in result[0].data if isinstance(d, Image.Image)]) == 2
+
+
+def test_chunk_mixed_long_content(tmp_path):
+    """Test chunk_mixed function with long content that gets chunked"""
+    from memory.common.db.models.source_item import chunk_mixed
+
+    # Create long content
+    long_content = "Lorem ipsum dolor sit amet, " * 50  # About 150 words
+
+    # Mock the chunker functions to force chunking behavior
+    with (
+        patch.object(settings, "FILE_STORAGE_DIR", tmp_path),
+        patch.object(chunker, "DEFAULT_CHUNK_TOKENS", 10),
+        patch.object(chunker, "approx_token_count", return_value=100),
+    ):  # Force it to be > 2 * 10
+        result = chunk_mixed(long_content, [])
+
+    # Should have multiple chunks: full_text + chunked pieces + summary
+    assert len(result) > 1
+
+    # First chunk should be full text
+    assert long_content.strip() in result[0].data
+
+    # Last chunk should be summary
+    # (we can't easily test the exact summary without mocking summarizer)
+    assert result[-1].data  # Should have some data
+
+
+@pytest.mark.parametrize(
+    "sha256_values,expected_committed",
+    [
+        ([b"unique1", b"unique2", b"unique3"], 3),  # All unique
+        ([b"duplicate", b"duplicate", b"unique"], 2),  # One duplicate pair
+        ([b"same", b"same", b"same"], 1),  # All duplicates
+        ([b"dup1", b"dup1", b"dup2", b"dup2"], 2),  # Two duplicate pairs
+    ],
+)
+def test_handle_duplicate_sha256_behavior(
+    db_session: Session, sha256_values, expected_committed
+):
+    """Test that handle_duplicate_sha256 event listener prevents duplicate sha256 values"""
+    # Create SourceItems with the given sha256 values
+    items = []
+    for i, sha256 in enumerate(sha256_values):
+        item = SourceItem(sha256=sha256, content=f"test content {i}", modality="text")
+        items.append(item)
+        db_session.add(item)
+
+    # Commit should trigger the event listener
+    db_session.commit()
+
+    # Query how many items were actually committed
+    committed_count = db_session.query(SourceItem).count()
+    assert committed_count == expected_committed
+
+    # Verify all sha256 values in database are unique
+    sha256_in_db = [row[0] for row in db_session.query(SourceItem.sha256).all()]
+    assert len(sha256_in_db) == len(set(sha256_in_db))  # All unique
+
+
+def test_handle_duplicate_sha256_with_existing_data(db_session: Session):
+    """Test duplicate handling when items already exist in database"""
+    # Add initial items
+    existing_item = SourceItem(sha256=b"existing", content="original", modality="text")
+    db_session.add(existing_item)
+    db_session.commit()
+
+    # Try to add new items with same and different sha256
+    new_items = [
+        SourceItem(
+            sha256=b"existing", content="duplicate", modality="text"
+        ),  # Should be rejected
+        SourceItem(
+            sha256=b"new_unique", content="new content", modality="text"
+        ),  # Should be kept
+    ]
+    for item in new_items:
+        db_session.add(item)
+
+    db_session.commit()
+
+    # Should have 2 items total (original + new unique)
+    assert db_session.query(SourceItem).count() == 2
+
+    # Original content should be preserved
+    existing_in_db = db_session.query(SourceItem).filter_by(sha256=b"existing").first()
+    assert existing_in_db is not None
+    assert str(existing_in_db.content) == "original"  # Original should be preserved
tests/memory/common/db/models/test_users.py (new file, 104 lines)

@@ -0,0 +1,104 @@
+import pytest
+from memory.common.db.models.users import hash_password, verify_password
+
+
+@pytest.mark.parametrize(
+    "password",
+    [
+        "simple_password",
+        "complex_P@ssw0rd!",
+        "very_long_password_with_many_characters_1234567890",
+        "",
+        "unicode_password_тест_😀",
+        "password with spaces",
+    ],
+)
+def test_hash_password_format(password):
+    """Test that hash_password returns correctly formatted hash"""
+    result = hash_password(password)
+
+    # Should be in format "salt:hash"
+    assert ":" in result
+    parts = result.split(":", 1)
+    assert len(parts) == 2
+
+    salt, hash_value = parts
+    # Salt should be 32 hex characters (16 bytes * 2)
+    assert len(salt) == 32
+    assert all(c in "0123456789abcdef" for c in salt)
+
+    # Hash should be 64 hex characters (SHA-256 = 32 bytes * 2)
+    assert len(hash_value) == 64
+    assert all(c in "0123456789abcdef" for c in hash_value)
+
+
+def test_hash_password_uniqueness():
+    """Test that same password generates different hashes due to random salt"""
+    password = "test_password"
+    hash1 = hash_password(password)
+    hash2 = hash_password(password)
+
+    # Different salts should produce different hashes
+    assert hash1 != hash2
+
+    # But both should verify correctly
+    assert verify_password(password, hash1)
+    assert verify_password(password, hash2)
+
+
+@pytest.mark.parametrize(
+    "password,expected",
+    [
+        ("correct_password", True),
+        ("wrong_password", False),
+        ("", False),
+        ("CORRECT_PASSWORD", False),  # Case sensitive
+    ],
+)
+def test_verify_password_correctness(password, expected):
+    """Test password verification with correct and incorrect passwords"""
+    correct_password = "correct_password"
+    password_hash = hash_password(correct_password)
+
+    result = verify_password(password, password_hash)
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "malformed_hash",
+    [
+        "invalid_format",
+        "no_colon_here",
+        ":empty_salt",
+        "salt:",  # Empty hash
+        "",
+        "too:many:colons:here",
+        "salt:invalid_hex_zzz",
+        "salt:too_short_hash",
+    ],
+)
+def test_verify_password_malformed_hash(malformed_hash):
+    """Test that verify_password handles malformed hashes gracefully"""
+    result = verify_password("any_password", malformed_hash)
+    assert result is False
+
+
+@pytest.mark.parametrize(
+    "test_password",
+    [
+        "simple",
+        "complex_P@ssw0rd!123",
+        "",
+        "unicode_тест_😀",
+        "password with spaces and symbols !@#$%^&*()",
+    ],
+)
+def test_hash_verify_roundtrip(test_password):
+    """Test that hash and verify work correctly together"""
+    password_hash = hash_password(test_password)
+
+    # Correct password should verify
+    assert verify_password(test_password, password_hash)
+
+    # Wrong password should not verify
+    assert not verify_password(test_password + "_wrong", password_hash)
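These tests pin down the hash format ("salt:hash", 16-byte hex salt, SHA-256-sized hex digest) without showing the implementation itself. A hypothetical implementation that would satisfy them, purely for illustration; the repository's actual `hash_password`/`verify_password` may well differ, for example by using a proper key-derivation function:

```python
import hashlib
import os


def hash_password(password: str) -> str:
    # Assumed scheme: random 16-byte salt, SHA-256 over salt + password,
    # stored as "salt_hex:digest_hex". A sketch, not the repo's code.
    salt = os.urandom(16).hex()
    digest = hashlib.sha256((salt + password).encode("utf-8")).hexdigest()
    return f"{salt}:{digest}"


def verify_password(password: str, password_hash: str) -> bool:
    # Malformed hashes simply fail verification, as the tests expect.
    if ":" not in password_hash:
        return False
    salt, digest = password_hash.split(":", 1)
    candidate = hashlib.sha256((salt + password).encode("utf-8")).hexdigest()
    return candidate == digest


assert verify_password("secret", hash_password("secret"))
assert not verify_password("wrong", hash_password("secret"))
assert not verify_password("anything", "invalid_format")
```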
@@ -1,23 +1,30 @@
 from unittest.mock import Mock
 import pytest
+from typing import cast
+from PIL import Image
+
-from memory.common import collections
+from memory.common import collections, settings
 from memory.common.embedding import (
+    as_string,
+    embed_chunks,
     embed_mixed,
     embed_text,
+    break_chunk,
+    embed_by_model,
 )
-from memory.common.extract import DataChunk
+from memory.common.extract import DataChunk, MulitmodalChunk
+from memory.common.db.models import Chunk, SourceItem


 @pytest.fixture
 def mock_embed(mock_voyage_client):
     vectors = ([i] for i in range(1000))

-    def embed(texts, model, input_type):
+    def embed_func(texts, model, input_type):
         return Mock(embeddings=[next(vectors) for _ in texts])

-    mock_voyage_client.embed = embed
-    mock_voyage_client.multimodal_embed = embed
+    mock_voyage_client.embed = Mock(side_effect=embed_func)
+    mock_voyage_client.multimodal_embed = Mock(side_effect=embed_func)

     return mock_voyage_client

@@ -52,3 +59,182 @@ def test_embed_text(mock_embed):
 def test_embed_mixed(mock_embed):
     items = [DataChunk(data=["text"])]
     assert embed_mixed(items) == [[0]]
+
+
+@pytest.mark.parametrize(
+    "input_data, expected_output",
+    [
+        ("hello world", "hello world"),
+        (" hello world \n", "hello world"),
+        (
+            cast(list[MulitmodalChunk], ["first chunk", "second chunk", "third chunk"]),
+            "first chunk\nsecond chunk\nthird chunk",
+        ),
+        (cast(list[MulitmodalChunk], []), ""),
+        (
+            cast(list[MulitmodalChunk], ["", "valid text", " ", "another text"]),
+            "valid text\n\nanother text",
+        ),
+    ],
+)
+def test_as_string_basic_cases(input_data, expected_output):
+    assert as_string(input_data) == expected_output
+
+
+def test_as_string_with_nested_lists():
+    # This tests the recursive nature of as_string - kept separate due to different input type
+    chunks = [["nested", "items"], "single item"]
+    result = as_string(chunks)
+    assert result == "nested\nitems\nsingle item"
+
+
+def test_embed_chunks_with_text_model(mock_embed):
+    chunks = cast(list[list[MulitmodalChunk]], [["text1"], ["text2"]])
+    result = embed_chunks(chunks, model=settings.TEXT_EMBEDDING_MODEL)
+    assert result == [[0], [1]]
+    mock_embed.embed.assert_called_once_with(
+        ["text1", "text2"],
+        model=settings.TEXT_EMBEDDING_MODEL,
+        input_type="document",
+    )
+
+
+def test_embed_chunks_with_mixed_model(mock_embed):
+    chunks = cast(list[list[MulitmodalChunk]], [["text with image"], ["another chunk"]])
+    result = embed_chunks(chunks, model=settings.MIXED_EMBEDDING_MODEL)
+    assert result == [[0], [1]]
+    mock_embed.multimodal_embed.assert_called_once_with(
+        chunks, model=settings.MIXED_EMBEDDING_MODEL, input_type="document"
+    )
+
+
+def test_embed_chunks_with_query_input_type(mock_embed):
+    chunks = cast(list[list[MulitmodalChunk]], [["query text"]])
+    result = embed_chunks(chunks, input_type="query")
+    assert result == [[0]]
+    mock_embed.embed.assert_called_once_with(
+        ["query text"], model=settings.TEXT_EMBEDDING_MODEL, input_type="query"
+    )
+
+
+def test_embed_chunks_empty_list(mock_embed):
+    result = embed_chunks([])
+    assert result == []
+
+
+@pytest.mark.parametrize(
+    "data, chunk_size, expected_result",
+    [
+        (["short text"], 100, ["short text"]),
+        (["some text content"], 200, ["some text content"]),
+        ([], 100, []),
+    ],
+)
+def test_break_chunk_simple_cases(data, chunk_size, expected_result):
+    chunk = DataChunk(data=data)
+    result = break_chunk(chunk, chunk_size=chunk_size)
+    assert result == expected_result
+
+
+def test_break_chunk_with_long_text():
+    # Create text that will exceed chunk size
+    long_text = "word " * 200  # Should be much longer than default chunk size
+    chunk = DataChunk(data=[long_text])
+    result = break_chunk(chunk, chunk_size=50)
+
+    # Should be broken into multiple chunks
+    assert len(result) > 1
+    assert all(isinstance(item, str) for item in result)
+
+
+def test_break_chunk_with_mixed_data_types():
+    # Mock image object
+    mock_image = Mock(spec=Image.Image)
+    chunk = DataChunk(data=["text content", mock_image])
+    result = break_chunk(chunk, chunk_size=100)
+
+    # Should have text chunks plus the original chunk (since it's not a string)
+    assert len(result) >= 2
+    assert any(isinstance(item, str) for item in result)
+    # The original chunk should be preserved when it contains mixed data
+    assert chunk in result
+
+
+def test_embed_by_model_with_matching_chunks(mock_embed):
+    # Create mock chunks with specific embedding model
+    chunk1 = Mock(spec=Chunk)
+    chunk1.embedding_model = "test-model"
+    chunk1.chunks = ["chunk1 content"]
+
+    chunk2 = Mock(spec=Chunk)
+    chunk2.embedding_model = "test-model"
+    chunk2.chunks = ["chunk2 content"]
+
+    chunks = cast(list[Chunk], [chunk1, chunk2])
+    result = embed_by_model(chunks, "test-model")
+
+    assert len(result) == 2
+    assert chunk1.vector == [0]
+    assert chunk2.vector == [1]
+    assert result == [chunk1, chunk2]
+
+
+def test_embed_by_model_with_no_matching_chunks(mock_embed):
+    chunk1 = Mock(spec=Chunk)
+    chunk1.embedding_model = "different-model"
+    # Ensure the chunk doesn't have a vector initially
+    del chunk1.vector
+
+    chunks = cast(list[Chunk], [chunk1])
+    result = embed_by_model(chunks, "test-model")
+
+    assert result == []
+    assert not hasattr(chunk1, "vector")
+
+
+def test_embed_by_model_with_mixed_models(mock_embed):
+    chunk1 = Mock(spec=Chunk)
+    chunk1.embedding_model = "test-model"
+    chunk1.chunks = ["chunk1 content"]
+
+    chunk2 = Mock(spec=Chunk)
+    chunk2.embedding_model = "other-model"
+    chunk2.chunks = ["chunk2 content"]
+
+    chunk3 = Mock(spec=Chunk)
+    chunk3.embedding_model = "test-model"
+    chunk3.chunks = ["chunk3 content"]
+
+    chunks = cast(list[Chunk], [chunk1, chunk2, chunk3])
+    result = embed_by_model(chunks, "test-model")
+
+    assert len(result) == 2
+    assert chunk1 in result
+    assert chunk3 in result
+    assert chunk2 not in result
+    assert chunk1.vector == [0]
+    assert chunk3.vector == [1]
+
+
+def test_embed_by_model_with_empty_chunks(mock_embed):
+    result = embed_by_model([], "test-model")
+    assert result == []
+
+
+def test_embed_by_model_calls_embed_chunks_correctly(mock_embed):
+    chunk1 = Mock(spec=Chunk)
+    chunk1.embedding_model = "test-model"
+    chunk1.chunks = ["content1"]
+
+    chunk2 = Mock(spec=Chunk)
+    chunk2.embedding_model = "test-model"
+    chunk2.chunks = ["content2"]
+
+    chunks = cast(list[Chunk], [chunk1, chunk2])
+    embed_by_model(chunks, "test-model")
+
+    # Verify embed_chunks was called with the right model
+    expected_chunks = [["content1"], ["content2"]]
+    mock_embed.embed.assert_called_once_with(
+        ["content1", "content2"], model="test-model", input_type="document"
+    )
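The new `as_string` tests imply a small helper that flattens nested chunk lists, keeps the string parts, joins them with newlines, and strips the result. A sketch of a function with that behavior, written from the test expectations rather than from the module itself:

```python
def as_string_sketch(chunk) -> str:
    # Behavior inferred from the tests above: a bare string is stripped; a list
    # is flattened recursively, each piece stripped, joined with newlines, and
    # the final result stripped again. Non-string items are ignored here.
    if isinstance(chunk, str):
        return chunk.strip()
    if isinstance(chunk, list):
        return "\n".join(as_string_sketch(part) for part in chunk).strip()
    return ""


assert as_string_sketch(" hello world \n") == "hello world"
assert as_string_sketch(["", "valid text", " ", "another text"]) == "valid text\n\nanother text"
assert as_string_sketch([["nested", "items"], "single item"]) == "nested\nitems\nsingle item"
```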