shuffle around search

Daniel O'Connell 2025-06-28 04:33:27 +02:00
parent 01ccea2733
commit 06eec621c1
10 changed files with 197 additions and 300 deletions

View File

@@ -90,6 +90,7 @@ export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
<div className="search-result-card">
<h4>{title}</h4>
<Tag tags={tags} />
<Metadata metadata={metadata} />
<div className="image-container">
{mime_type && mime_type?.startsWith('image/') && <img src={`data:${mime_type};base64,${content}`} alt={title} className="search-result-image"/>}
</div>
@@ -115,7 +116,7 @@ export const Metadata = ({ metadata }: { metadata: any }) => {
return (
<div className="metadata">
<ul>
{Object.entries(metadata).map(([key, value]) => (
{Object.entries(metadata).filter(([key, value]) => ![null, undefined].includes(value as any)).map(([key, value]) => (
<MetadataItem key={key} item={key} value={value} />
))}
</ul>
@@ -154,19 +155,19 @@ export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
}
export const SearchResult = ({ result }: { result: SearchItem }) => {
if (result.mime_type.startsWith('image/')) {
if (result.mime_type?.startsWith('image/')) {
return <ImageResult {...result} />
}
if (result.mime_type.startsWith('text/markdown')) {
if (result.mime_type?.startsWith('text/markdown')) {
return <MarkdownResult {...result} />
}
if (result.mime_type.startsWith('text/')) {
if (result.mime_type?.startsWith('text/')) {
return <TextResult {...result} />
}
if (result.mime_type.startsWith('application/pdf')) {
if (result.mime_type?.startsWith('application/pdf')) {
return <PDFResult {...result} />
}
if (result.mime_type.startsWith('message/rfc822')) {
if (result.mime_type?.startsWith('message/rfc822')) {
return <EmailResult {...result} />
}
console.log(result)

View File

@@ -109,14 +109,12 @@ async def search_knowledge_base(
search_filters = SearchFilters(**filters)
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
upload_data = extract.extract_text(query)
upload_data = extract.extract_text(query, skip_summary=True)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
min_text_score=0.4,
min_multimodal_score=0.25,
filters=search_filters,
)

View File

@@ -1,4 +1,4 @@
from .search import search
from .utils import SearchResult, SearchFilters
from .types import SearchResult, SearchFilters
__all__ = ["search", "SearchResult", "SearchFilters"]

View File

@@ -2,13 +2,15 @@
Search endpoints for the knowledge base API.
"""
import asyncio
from hashlib import sha256
import logging
import bm25s
import Stemmer
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
from memory.api.search.types import SearchFilters
from memory.common import extract
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, ConfidenceScore
@@ -20,7 +22,7 @@ async def search_bm25(
modalities: set[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
) -> list[tuple[SourceData, AnnotatedChunk]]:
) -> list[str]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities),
@@ -65,21 +67,18 @@ async def search_bm25(
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
return list(item_scores.keys())
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
results = []
for chunk in chunks:
# Prefetch all needed source data while in session
source_data = SourceData.from_chunk(chunk)
annotated = AnnotatedChunk(
id=str(chunk.id),
score=item_scores[chunk.id],
metadata=chunk.source.as_payload(),
preview=None,
search_method="bm25",
async def search_bm25_chunks(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = SearchFilters(),
timeout: int = 2,
) -> list[str]:
query = " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)])
return await asyncio.wait_for(
search_bm25(query, modalities, limit, filters),
timeout,
)
results.append((source_data, annotated))
return results
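
One detail worth calling out in the BM25 path: indexed documents are keyed by the SHA-256 of their content, so the scores that bm25s returns are mapped back to chunk IDs through that hash before the bare ID list is returned. A small self-contained sketch of that mapping, with the actual retrieval stubbed out (all names and data below are illustrative, not the project's real API):

from hashlib import sha256

# Chunk contents keyed by chunk ID, as they would come back from the database query.
chunks = {"chunk-1": "cats purr", "chunk-2": "dogs bark"}

# Index side: remember which content hash belongs to which chunk ID.
item_ids = {
    sha256(text.encode("utf-8")).hexdigest(): chunk_id for chunk_id, text in chunks.items()
}

# Pretend the BM25 retriever returned these (document, score) pairs for a query.
retrieved = [("dogs bark", 1.7), ("cats purr", 0.4)]

# Map scores back onto chunk IDs via the content hash, then keep only the IDs.
item_scores = {
    item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score for doc, score in retrieved
}
print(list(item_scores.keys()))  # ['chunk-2', 'chunk-1']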

View File

@@ -1,65 +1,20 @@
import base64
import io
import logging
import asyncio
from typing import Any, Callable, Optional, cast
from typing import Any, Callable, cast
import qdrant_client
from PIL import Image
from qdrant_client.http import models as qdrant_models
from memory.common import embedding, extract, qdrant, settings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk
from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
from memory.common import embedding, extract, qdrant
from memory.common.collections import (
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.api.search.types import SearchFilters
logger = logging.getLogger(__name__)
def annotated_chunk(
chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
) -> tuple[SourceData, AnnotatedChunk]:
def serialize_item(item: bytes | str | Image.Image) -> str | None:
if not previews and not isinstance(item, str):
return None
if (
not previews
and isinstance(item, str)
and len(item) > settings.MAX_NON_PREVIEW_LENGTH
):
return item[: settings.MAX_NON_PREVIEW_LENGTH] + "..."
elif isinstance(item, str):
if len(item) > settings.MAX_PREVIEW_LENGTH:
return None
return item
if isinstance(item, Image.Image):
buffer = io.BytesIO()
format = item.format or "PNG"
item.save(buffer, format=format)
mime_type = f"image/{format.lower()}"
return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
elif isinstance(item, bytes):
return base64.b64encode(item).decode("utf-8")
else:
raise ValueError(f"Unsupported item type: {type(item)}")
metadata = search_result.payload or {}
metadata = {
k: v
for k, v in metadata.items()
if k not in ["content", "filename", "size", "content_type", "tags"]
}
# Prefetch all needed source data while in session
return SourceData.from_chunk(chunk), AnnotatedChunk(
id=str(chunk.id),
score=search_result.score,
metadata=metadata,
preview=serialize_item(chunk.data[0]) if chunk.data else None,
search_method="embeddings",
)
async def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
@@ -178,15 +133,14 @@ def merge_filters(
return filters
async def search_embeddings(
async def search_chunks(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = {},
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
) -> list[str]:
"""
Search across knowledge base using text query and optional files.
@@ -218,9 +172,38 @@ async def search_embeddings(
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
with make_session() as db:
chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
return [
annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
for chunk in chunks
]
return list(found_chunks.keys())
async def search_chunks_embeddings(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = SearchFilters(),
timeout: int = 2,
) -> list[str]:
all_ids = await asyncio.gather(
asyncio.wait_for(
search_chunks(
data,
modalities & TEXT_COLLECTIONS,
limit,
0.4,
filters,
False,
),
timeout,
),
asyncio.wait_for(
search_chunks(
data,
modalities & MULTIMODAL_COLLECTIONS,
limit,
0.25,
filters,
True,
),
timeout,
),
)
return list({id for ids in all_ids for id in ids})
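
The embeddings side follows the same shape: text and multimodal collections are queried concurrently, each bounded by the same timeout but with its own minimum score (0.4 for text, 0.25 for multimodal), and the resulting IDs are unioned. A rough sketch of that fan-out under stated assumptions (the collection sets and the inner vector search are stand-ins, not the real qdrant call):

import asyncio

TEXT_COLLECTIONS = {"mail", "blog"}
MULTIMODAL_COLLECTIONS = {"photo", "doc"}


async def fake_vector_search(modalities: set[str], min_score: float) -> list[str]:
    # Stand-in for the real vector query; returns chunk IDs scoring above min_score.
    return [f"{modality}-chunk" for modality in sorted(modalities)]


async def search_chunks_embeddings(modalities: set[str], timeout: float = 2.0) -> list[str]:
    text_ids, multimodal_ids = await asyncio.gather(
        asyncio.wait_for(fake_vector_search(modalities & TEXT_COLLECTIONS, 0.4), timeout),
        asyncio.wait_for(fake_vector_search(modalities & MULTIMODAL_COLLECTIONS, 0.25), timeout),
    )
    # Union the two ID lists so a chunk found by both paths is only reported once.
    return list({*text_ids, *multimodal_ids})


print(asyncio.run(search_chunks_embeddings({"mail", "photo"})))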

View File

@@ -4,36 +4,70 @@ Search endpoints for the knowledge base API.
import asyncio
import logging
from collections import defaultdict
from typing import Optional
from sqlalchemy.orm import load_only
from memory.common import extract, settings
from memory.common.collections import (
ALL_COLLECTIONS,
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
from memory.api.search.embeddings import search_embeddings
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, SourceItem
from memory.common.collections import ALL_COLLECTIONS
from memory.api.search.embeddings import search_chunks_embeddings
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25
from memory.api.search.bm25 import search_bm25_chunks
from memory.api.search.utils import (
SearchFilters,
SearchResult,
group_chunks,
with_timeout,
)
from memory.api.search.types import SearchFilters, SearchResult
logger = logging.getLogger(__name__)
async def search_chunks(
data: list[extract.DataChunk],
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[Chunk]:
funcs = [search_chunks_embeddings]
if settings.ENABLE_BM25_SEARCH:
funcs.append(search_bm25_chunks)
all_ids = await asyncio.gather(
*[func(data, modalities, limit, filters, timeout) for func in funcs]
)
all_ids = {id for ids in all_ids for id in ids}
with make_session() as db:
chunks = (
db.query(Chunk)
.options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
.filter(Chunk.id.in_(all_ids))
.all()
)
db.expunge_all()
return chunks
async def search_sources(
chunks: list[Chunk], previews: Optional[bool] = False
) -> list[SearchResult]:
by_source = defaultdict(list)
for chunk in chunks:
by_source[chunk.source_id].append(chunk)
with make_session() as db:
sources = db.query(SourceItem).filter(SourceItem.id.in_(by_source.keys())).all()
return [
SearchResult.from_source_item(source, by_source[source.id], previews)
for source in sources
]
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_text_score: float = 0.4,
min_multimodal_score: float = 0.25,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[SearchResult]:
@@ -50,56 +84,11 @@ async def search(
- List of search results sorted by score
"""
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
searches = []
if settings.ENABLE_EMBEDDING_SEARCH:
searches = [
with_timeout(
search_embeddings(
chunks = await search_chunks(
data,
previews,
allowed_modalities & TEXT_COLLECTIONS,
allowed_modalities,
limit,
min_text_score,
filters,
multimodal=False,
),
timeout,
),
with_timeout(
search_embeddings(
data,
previews,
allowed_modalities & MULTIMODAL_COLLECTIONS,
limit,
min_multimodal_score,
filters,
multimodal=True,
),
timeout,
),
]
if settings.ENABLE_BM25_SEARCH:
searches.append(
with_timeout(
search_bm25(
" ".join(
[c for chunk in data for c in chunk.data if isinstance(c, str)]
),
modalities,
limit=limit,
filters=filters,
),
timeout,
)
)
search_results = await asyncio.gather(*searches, return_exceptions=False)
all_results = []
for results in search_results:
if len(all_results) >= limit:
break
all_results.extend(results)
results = group_chunks(all_results, previews or False)
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
return await search_sources(chunks, previews)
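
With both backends reduced to ID producers, search() becomes a two-step pipeline: gather and deduplicate chunk IDs, load the chunks, then fold them into one result per source item via search_sources. A dependency-free sketch of that final grouping step, assuming a chunk carries a source_id and over-long chunk contents get elided (FakeChunk, FakeResult, and group_by_source are hypothetical stand-ins for the ORM models and SearchResult.from_source_item):

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class FakeChunk:
    id: str
    source_id: int
    content: str


@dataclass
class FakeResult:
    source_id: int
    chunks: list[str]


def group_by_source(chunks: list[FakeChunk], max_chars: int = 100) -> list[FakeResult]:
    # Fold chunks into one result per source, eliding over-long chunk contents.
    by_source: dict[int, list[FakeChunk]] = defaultdict(list)
    for chunk in chunks:
        by_source[chunk.source_id].append(chunk)
    return [
        FakeResult(
            source_id=source_id,
            chunks=[
                c.content if len(c.content) <= max_chars else c.content[:max_chars] + "..."
                for c in source_chunks
            ],
        )
        for source_id, source_chunks in by_source.items()
    ]


results = group_by_source(
    [
        FakeChunk("a", 1, "first chunk"),
        FakeChunk("b", 1, "second chunk"),
        FakeChunk("c", 2, "chunk from another source"),
    ]
)
print(results)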

View File

@@ -0,0 +1,67 @@
from datetime import datetime
import logging
from typing import Optional, TypedDict, NotRequired, cast
from memory.common.db.models.source_item import SourceItem
from pydantic import BaseModel
from memory.common.db.models import Chunk
from memory.common import settings
logger = logging.getLogger(__name__)
class SearchResponse(BaseModel):
collection: str
results: list[dict]
def elide_content(content: str, max_length: int = 100) -> str:
if content and len(content) > max_length:
return content[:max_length] + "..."
return content
class SearchResult(BaseModel):
id: int
chunks: list[str]
size: int | None = None
mime_type: str | None = None
content: Optional[str | dict] = None
filename: Optional[str] = None
tags: list[str] | None = None
metadata: dict | None = None
created_at: datetime | None = None
@classmethod
def from_source_item(
cls, source: SourceItem, chunks: list[Chunk], previews: Optional[bool] = False
) -> "SearchResult":
metadata = source.display_contents or {}
metadata.pop("content", None)
chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
return cls(
id=cast(int, source.id),
size=cast(int, source.size),
mime_type=cast(str, source.mime_type),
chunks=[elide_content(str(chunk.content), chunk_size) for chunk in chunks],
content=elide_content(
cast(str, source.content),
settings.MAX_PREVIEW_LENGTH
if previews
else settings.MAX_NON_PREVIEW_LENGTH,
),
filename=cast(str, source.filename),
tags=cast(list[str], source.tags),
metadata=metadata,
created_at=cast(datetime | None, source.inserted_at),
)
class SearchFilters(TypedDict):
min_size: NotRequired[int]
max_size: NotRequired[int]
min_confidences: NotRequired[dict[str, float]]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]

View File

@@ -1,140 +0,0 @@
import asyncio
import traceback
from datetime import datetime
import logging
from collections import defaultdict
from typing import Optional, TypedDict, NotRequired
from pydantic import BaseModel
from memory.common import settings
from memory.common.db.models import Chunk
logger = logging.getLogger(__name__)
class AnnotatedChunk(BaseModel):
id: str
score: float
metadata: dict
preview: Optional[str | None] = None
search_method: str | None = None
class SourceData(BaseModel):
"""Holds source item data to avoid SQLAlchemy session issues"""
id: int
size: int | None
mime_type: str | None
filename: str | None
content_length: int
contents: dict | str | None
created_at: datetime | None
@staticmethod
def from_chunk(chunk: Chunk) -> "SourceData":
source = chunk.source
display_contents = source.display_contents or {}
return SourceData(
id=source.id,
size=source.size,
mime_type=source.mime_type,
filename=source.filename,
content_length=len(source.content) if source.content else 0,
contents={k: v for k, v in display_contents.items() if v is not None},
created_at=source.inserted_at,
)
class SearchResponse(BaseModel):
collection: str
results: list[dict]
class SearchResult(BaseModel):
id: int
size: int
mime_type: str
chunks: list[AnnotatedChunk]
content: Optional[str | dict] = None
filename: Optional[str] = None
tags: list[str] | None = None
metadata: dict | None = None
created_at: datetime | None = None
class SearchFilters(TypedDict):
min_size: NotRequired[int]
max_size: NotRequired[int]
min_confidences: NotRequired[dict[str, float]]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
async def with_timeout(
call, timeout: int = 2
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
Run a function with a timeout.
Args:
call: The function to run
timeout: The timeout in seconds
"""
try:
return await asyncio.wait_for(call, timeout=timeout)
except TimeoutError:
logger.warning(f"Search timed out after {timeout}s")
return []
except Exception as e:
traceback.print_exc()
logger.error(f"Search failed: {e}")
return []
def group_chunks(
chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
) -> list[SearchResult]:
items = defaultdict(list)
source_lookup = {}
for source, chunk in chunks:
items[source.id].append(chunk)
source_lookup[source.id] = source
def get_content(text: str | dict | None) -> str | dict | None:
if isinstance(text, str) and len(text) > settings.MAX_PREVIEW_LENGTH:
return None
return text
def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
contents = source.contents or {}
tags = []
if isinstance(contents, dict):
tags = contents.pop("tags", [])
content = contents.pop("content", None)
else:
content = contents
contents = {}
return SearchResult(
id=source.id,
size=source.size or source.content_length,
mime_type=source.mime_type or "text/plain",
filename=source.filename
and source.filename.replace(
str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
),
content=get_content(content),
tags=tags,
metadata=contents,
chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
created_at=source.created_at,
)
return [
make_result(source, chunks)
for source_id, chunks in items.items()
for source in [source_lookup[source_id]]
]

View File

@@ -368,7 +368,7 @@ class SourceItem(Base):
return [cls.__tablename__]
@property
def display_contents(self) -> str | dict | None:
def display_contents(self) -> dict | None:
payload = self.as_payload()
payload.pop("source_id", None) # type: ignore
return {

View File

@@ -135,7 +135,7 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 8))
MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
# API settings