diff --git a/frontend/src/components/search/results.tsx b/frontend/src/components/search/results.tsx
index 79ab3ba..ada94e5 100644
--- a/frontend/src/components/search/results.tsx
+++ b/frontend/src/components/search/results.tsx
@@ -90,6 +90,7 @@ export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {

{title}

+    {mime_type && mime_type?.startsWith('image/') && {title}}
@@ -115,7 +116,7 @@ export const Metadata = ({ metadata }: { metadata: any }) => {
   return (
-      {Object.entries(metadata).map(([key, value]) => (
+      {Object.entries(metadata).filter(([key, value]) => ![null, undefined].includes(value as any)).map(([key, value]) => (
       ))}
@@ -154,19 +155,19 @@ export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
 }
 
 export const SearchResult = ({ result }: { result: SearchItem }) => {
-  if (result.mime_type.startsWith('image/')) {
+  if (result.mime_type?.startsWith('image/')) {
     return
   }
-  if (result.mime_type.startsWith('text/markdown')) {
+  if (result.mime_type?.startsWith('text/markdown')) {
     return
   }
-  if (result.mime_type.startsWith('text/')) {
+  if (result.mime_type?.startsWith('text/')) {
     return
   }
-  if (result.mime_type.startsWith('application/pdf')) {
+  if (result.mime_type?.startsWith('application/pdf')) {
     return
   }
-  if (result.mime_type.startsWith('message/rfc822')) {
+  if (result.mime_type?.startsWith('message/rfc822')) {
     return
   }
   console.log(result)
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py
index ad84418..50837b6 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/memory.py
@@ -109,14 +109,12 @@ async def search_knowledge_base(
     search_filters = SearchFilters(**filters)
     search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
 
-    upload_data = extract.extract_text(query)
+    upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
         previews=previews,
         modalities=modalities,
         limit=limit,
-        min_text_score=0.4,
-        min_multimodal_score=0.25,
         filters=search_filters,
     )
diff --git a/src/memory/api/search/__init__.py b/src/memory/api/search/__init__.py
index f3b1cf1..932e3fe 100644
--- a/src/memory/api/search/__init__.py
+++ b/src/memory/api/search/__init__.py
@@ -1,4 +1,4 @@
 from .search import search
-from .utils import SearchResult, SearchFilters
+from .types import SearchResult, SearchFilters
 
 __all__ = ["search", "SearchResult", "SearchFilters"]
diff --git a/src/memory/api/search/bm25.py b/src/memory/api/search/bm25.py
index 91e07f5..e20780f 100644
--- a/src/memory/api/search/bm25.py
+++ b/src/memory/api/search/bm25.py
@@ -2,13 +2,15 @@
 Search endpoints for the knowledge base API.
""" +import asyncio from hashlib import sha256 import logging import bm25s import Stemmer -from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters +from memory.api.search.types import SearchFilters +from memory.common import extract from memory.common.db.connection import make_session from memory.common.db.models import Chunk, ConfidenceScore @@ -20,7 +22,7 @@ async def search_bm25( modalities: set[str], limit: int = 10, filters: SearchFilters = SearchFilters(), -) -> list[tuple[SourceData, AnnotatedChunk]]: +) -> list[str]: with make_session() as db: items_query = db.query(Chunk.id, Chunk.content).filter( Chunk.collection_name.in_(modalities), @@ -65,21 +67,18 @@ async def search_bm25( item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score for doc, score in zip(results[0], scores[0]) } + return list(item_scores.keys()) - with make_session() as db: - chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all() - results = [] - for chunk in chunks: - # Prefetch all needed source data while in session - source_data = SourceData.from_chunk(chunk) - annotated = AnnotatedChunk( - id=str(chunk.id), - score=item_scores[chunk.id], - metadata=chunk.source.as_payload(), - preview=None, - search_method="bm25", - ) - results.append((source_data, annotated)) - - return results +async def search_bm25_chunks( + data: list[extract.DataChunk], + modalities: set[str] = set(), + limit: int = 10, + filters: SearchFilters = SearchFilters(), + timeout: int = 2, +) -> list[str]: + query = " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]) + return await asyncio.wait_for( + search_bm25(query, modalities, limit, filters), + timeout, + ) diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py index 717c0b3..dd5721b 100644 --- a/src/memory/api/search/embeddings.py +++ b/src/memory/api/search/embeddings.py @@ -1,65 +1,20 @@ -import base64 -import io import logging import asyncio -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, cast import qdrant_client -from PIL import Image from qdrant_client.http import models as qdrant_models -from memory.common import embedding, extract, qdrant, settings -from memory.common.db.connection import make_session -from memory.common.db.models import Chunk -from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters +from memory.common import embedding, extract, qdrant +from memory.common.collections import ( + MULTIMODAL_COLLECTIONS, + TEXT_COLLECTIONS, +) +from memory.api.search.types import SearchFilters logger = logging.getLogger(__name__) -def annotated_chunk( - chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool -) -> tuple[SourceData, AnnotatedChunk]: - def serialize_item(item: bytes | str | Image.Image) -> str | None: - if not previews and not isinstance(item, str): - return None - if ( - not previews - and isinstance(item, str) - and len(item) > settings.MAX_NON_PREVIEW_LENGTH - ): - return item[: settings.MAX_NON_PREVIEW_LENGTH] + "..." 
-        elif isinstance(item, str):
-            if len(item) > settings.MAX_PREVIEW_LENGTH:
-                return None
-            return item
-        if isinstance(item, Image.Image):
-            buffer = io.BytesIO()
-            format = item.format or "PNG"
-            item.save(buffer, format=format)
-            mime_type = f"image/{format.lower()}"
-            return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
-        elif isinstance(item, bytes):
-            return base64.b64encode(item).decode("utf-8")
-        else:
-            raise ValueError(f"Unsupported item type: {type(item)}")
-
-    metadata = search_result.payload or {}
-    metadata = {
-        k: v
-        for k, v in metadata.items()
-        if k not in ["content", "filename", "size", "content_type", "tags"]
-    }
-
-    # Prefetch all needed source data while in session
-    return SourceData.from_chunk(chunk), AnnotatedChunk(
-        id=str(chunk.id),
-        score=search_result.score,
-        metadata=metadata,
-        preview=serialize_item(chunk.data[0]) if chunk.data else None,
-        search_method="embeddings",
-    )
-
-
 async def query_chunks(
     client: qdrant_client.QdrantClient,
     upload_data: list[extract.DataChunk],
@@ -178,15 +133,14 @@ def merge_filters(
     return filters
 
 
-async def search_embeddings(
+async def search_chunks(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
     limit: int = 10,
     min_score: float = 0.3,
     filters: SearchFilters = {},
     multimodal: bool = False,
-) -> list[tuple[SourceData, AnnotatedChunk]]:
+) -> list[str]:
     """
     Search across knowledge base using text query and optional files.
 
@@ -218,9 +172,38 @@ async def search_embeddings(
     found_chunks = {
         str(r.id): r for results in search_results.values() for r in results
     }
-    with make_session() as db:
-        chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
-        return [
-            annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
-            for chunk in chunks
-        ]
+    return list(found_chunks.keys())
+
+
+async def search_chunks_embeddings(
+    data: list[extract.DataChunk],
+    modalities: set[str] = set(),
+    limit: int = 10,
+    filters: SearchFilters = SearchFilters(),
+    timeout: int = 2,
+) -> list[str]:
+    all_ids = await asyncio.gather(
+        asyncio.wait_for(
+            search_chunks(
+                data,
+                modalities & TEXT_COLLECTIONS,
+                limit,
+                0.4,
+                filters,
+                False,
+            ),
+            timeout,
+        ),
+        asyncio.wait_for(
+            search_chunks(
+                data,
+                modalities & MULTIMODAL_COLLECTIONS,
+                limit,
+                0.25,
+                filters,
+                True,
+            ),
+            timeout,
+        ),
+    )
+    return list({id for ids in all_ids for id in ids})
diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py
index f876b0b..85b39c0 100644
--- a/src/memory/api/search/search.py
+++ b/src/memory/api/search/search.py
@@ -4,36 +4,70 @@ Search endpoints for the knowledge base API.
 import asyncio
 import logging
+from collections import defaultdict
 from typing import Optional
-
+from sqlalchemy.orm import load_only
 from memory.common import extract, settings
-from memory.common.collections import (
-    ALL_COLLECTIONS,
-    MULTIMODAL_COLLECTIONS,
-    TEXT_COLLECTIONS,
-)
-from memory.api.search.embeddings import search_embeddings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+from memory.common.collections import ALL_COLLECTIONS
+from memory.api.search.embeddings import search_chunks_embeddings
 
 if settings.ENABLE_BM25_SEARCH:
-    from memory.api.search.bm25 import search_bm25
+    from memory.api.search.bm25 import search_bm25_chunks
 
-from memory.api.search.utils import (
-    SearchFilters,
-    SearchResult,
-    group_chunks,
-    with_timeout,
-)
+from memory.api.search.types import SearchFilters, SearchResult
 
 logger = logging.getLogger(__name__)
 
 
+async def search_chunks(
+    data: list[extract.DataChunk],
+    modalities: set[str] = set(),
+    limit: int = 10,
+    filters: SearchFilters = {},
+    timeout: int = 2,
+) -> list[Chunk]:
+    funcs = [search_chunks_embeddings]
+    if settings.ENABLE_BM25_SEARCH:
+        funcs.append(search_bm25_chunks)
+
+    all_ids = await asyncio.gather(
+        *[func(data, modalities, limit, filters, timeout) for func in funcs]
+    )
+    all_ids = {id for ids in all_ids for id in ids}
+
+    with make_session() as db:
+        chunks = (
+            db.query(Chunk)
+            .options(load_only(Chunk.id, Chunk.source_id, Chunk.content))  # type: ignore
+            .filter(Chunk.id.in_(all_ids))
+            .all()
+        )
+        db.expunge_all()
+        return chunks
+
+
+async def search_sources(
+    chunks: list[Chunk], previews: Optional[bool] = False
+) -> list[SearchResult]:
+    by_source = defaultdict(list)
+    for chunk in chunks:
+        by_source[chunk.source_id].append(chunk)
+
+    with make_session() as db:
+        sources = db.query(SourceItem).filter(SourceItem.id.in_(by_source.keys())).all()
+        return [
+            SearchResult.from_source_item(source, by_source[source.id], previews)
+            for source in sources
+        ]
+
+
 async def search(
     data: list[extract.DataChunk],
     previews: Optional[bool] = False,
     modalities: set[str] = set(),
     limit: int = 10,
-    min_text_score: float = 0.4,
-    min_multimodal_score: float = 0.25,
     filters: SearchFilters = {},
     timeout: int = 2,
 ) -> list[SearchResult]:
@@ -50,56 +84,11 @@ async def search(
       - List of search results sorted by score
     """
     allowed_modalities = modalities & ALL_COLLECTIONS.keys()
-
-    searches = []
-    if settings.ENABLE_EMBEDDING_SEARCH:
-        searches = [
-            with_timeout(
-                search_embeddings(
-                    data,
-                    previews,
-                    allowed_modalities & TEXT_COLLECTIONS,
-                    limit,
-                    min_text_score,
-                    filters,
-                    multimodal=False,
-                ),
-                timeout,
-            ),
-            with_timeout(
-                search_embeddings(
-                    data,
-                    previews,
-                    allowed_modalities & MULTIMODAL_COLLECTIONS,
-                    limit,
-                    min_multimodal_score,
-                    filters,
-                    multimodal=True,
-                ),
-                timeout,
-            ),
-        ]
-    if settings.ENABLE_BM25_SEARCH:
-        searches.append(
-            with_timeout(
-                search_bm25(
-                    " ".join(
-                        [c for chunk in data for c in chunk.data if isinstance(c, str)]
-                    ),
-                    modalities,
-                    limit=limit,
-                    filters=filters,
-                ),
-                timeout,
-            )
-        )
-
-    search_results = await asyncio.gather(*searches, return_exceptions=False)
-    all_results = []
-    for results in search_results:
-        if len(all_results) >= limit:
-            break
-        all_results.extend(results)
-
-    results = group_chunks(all_results, previews or False)
-    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
+    chunks = await search_chunks(
+        data,
+        allowed_modalities,
+        limit,
+        filters,
+        timeout,
+    )
+    return await search_sources(chunks, previews)
diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py
new file mode 100644
index 0000000..def4a26
--- /dev/null
+++ b/src/memory/api/search/types.py
@@ -0,0 +1,67 @@
+from datetime import datetime
+import logging
+from typing import Optional, TypedDict, NotRequired, cast
+
+from memory.common.db.models.source_item import SourceItem
+from pydantic import BaseModel
+
+from memory.common.db.models import Chunk
+from memory.common import settings
+
+logger = logging.getLogger(__name__)
+
+
+class SearchResponse(BaseModel):
+    collection: str
+    results: list[dict]
+
+
+def elide_content(content: str, max_length: int = 100) -> str:
+    if content and len(content) > max_length:
+        return content[:max_length] + "..."
+    return content
+
+
+class SearchResult(BaseModel):
+    id: int
+    chunks: list[str]
+    size: int | None = None
+    mime_type: str | None = None
+    content: Optional[str | dict] = None
+    filename: Optional[str] = None
+    tags: list[str] | None = None
+    metadata: dict | None = None
+    created_at: datetime | None = None
+
+    @classmethod
+    def from_source_item(
+        cls, source: SourceItem, chunks: list[Chunk], previews: Optional[bool] = False
+    ) -> "SearchResult":
+        metadata = source.display_contents or {}
+        metadata.pop("content", None)
+        chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
+
+        return cls(
+            id=cast(int, source.id),
+            size=cast(int, source.size),
+            mime_type=cast(str, source.mime_type),
+            chunks=[elide_content(str(chunk.content), chunk_size) for chunk in chunks],
+            content=elide_content(
+                cast(str, source.content),
+                settings.MAX_PREVIEW_LENGTH
+                if previews
+                else settings.MAX_NON_PREVIEW_LENGTH,
+            ),
+            filename=cast(str, source.filename),
+            tags=cast(list[str], source.tags),
+            metadata=metadata,
+            created_at=cast(datetime | None, source.inserted_at),
+        )
+
+
+class SearchFilters(TypedDict):
+    min_size: NotRequired[int]
+    max_size: NotRequired[int]
+    min_confidences: NotRequired[dict[str, float]]
+    observation_types: NotRequired[list[str] | None]
+    source_ids: NotRequired[list[int] | None]
diff --git a/src/memory/api/search/utils.py b/src/memory/api/search/utils.py
deleted file mode 100644
index 98ae1c1..0000000
--- a/src/memory/api/search/utils.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import asyncio
-import traceback
-from datetime import datetime
-import logging
-from collections import defaultdict
-from typing import Optional, TypedDict, NotRequired
-
-from pydantic import BaseModel
-
-from memory.common import settings
-from memory.common.db.models import Chunk
-
-logger = logging.getLogger(__name__)
-
-
-class AnnotatedChunk(BaseModel):
-    id: str
-    score: float
-    metadata: dict
-    preview: Optional[str | None] = None
-    search_method: str | None = None
-
-
-class SourceData(BaseModel):
-    """Holds source item data to avoid SQLAlchemy session issues"""
-
-    id: int
-    size: int | None
-    mime_type: str | None
-    filename: str | None
-    content_length: int
-    contents: dict | str | None
-    created_at: datetime | None
-
-    @staticmethod
-    def from_chunk(chunk: Chunk) -> "SourceData":
-        source = chunk.source
-        display_contents = source.display_contents or {}
-        return SourceData(
-            id=source.id,
-            size=source.size,
-            mime_type=source.mime_type,
-            filename=source.filename,
-            content_length=len(source.content) if source.content else 0,
-            contents={k: v for k, v in display_contents.items() if v is not None},
-            created_at=source.inserted_at,
-        )
-
-
-class SearchResponse(BaseModel):
-    collection: str
-    results: list[dict]
-
-
-class SearchResult(BaseModel):
-    id: int
-    size: int
-    mime_type: str
-    chunks: list[AnnotatedChunk]
-    content: Optional[str | dict] = None
-    filename: Optional[str] = None
-    tags: list[str] | None = None
-    metadata: dict | None = None
-    created_at: datetime | None = None
-
-
-class SearchFilters(TypedDict):
-    min_size: NotRequired[int]
-    max_size: NotRequired[int]
-    min_confidences: NotRequired[dict[str, float]]
-    observation_types: NotRequired[list[str] | None]
-    source_ids: NotRequired[list[int] | None]
-
-
-async def with_timeout(
-    call, timeout: int = 2
-) -> list[tuple[SourceData, AnnotatedChunk]]:
-    """
-    Run a function with a timeout.
-
-    Args:
-        call: The function to run
-        timeout: The timeout in seconds
-    """
-    try:
-        return await asyncio.wait_for(call, timeout=timeout)
-    except TimeoutError:
-        logger.warning(f"Search timed out after {timeout}s")
-        return []
-    except Exception as e:
-        traceback.print_exc()
-        logger.error(f"Search failed: {e}")
-        return []
-
-
-def group_chunks(
-    chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
-) -> list[SearchResult]:
-    items = defaultdict(list)
-    source_lookup = {}
-
-    for source, chunk in chunks:
-        items[source.id].append(chunk)
-        source_lookup[source.id] = source
-
-    def get_content(text: str | dict | None) -> str | dict | None:
-        if isinstance(text, str) and len(text) > settings.MAX_PREVIEW_LENGTH:
-            return None
-        return text
-
-    def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
-        contents = source.contents or {}
-        tags = []
-        if isinstance(contents, dict):
-            tags = contents.pop("tags", [])
-            content = contents.pop("content", None)
-        else:
-            content = contents
-            contents = {}
-
-        return SearchResult(
-            id=source.id,
-            size=source.size or source.content_length,
-            mime_type=source.mime_type or "text/plain",
-            filename=source.filename
-            and source.filename.replace(
-                str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
-            ),
-            content=get_content(content),
-            tags=tags,
-            metadata=contents,
-            chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
-            created_at=source.created_at,
-        )
-
-    return [
-        make_result(source, chunks)
-        for source_id, chunks in items.items()
-        for source in [source_lookup[source_id]]
-    ]
diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py
index b7ed5ae..0035921 100644
--- a/src/memory/common/db/models/source_item.py
+++ b/src/memory/common/db/models/source_item.py
@@ -368,7 +368,7 @@ class SourceItem(Base):
         return [cls.__tablename__]
 
     @property
-    def display_contents(self) -> str | dict | None:
+    def display_contents(self) -> dict | None:
         payload = self.as_payload()
         payload.pop("source_id", None)  # type: ignore
         return {
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 4033351..2943506 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -135,7 +135,7 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
 # Search settings
 ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
 ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
-MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 8))
+MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
 MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
 
 # API settings
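
The patch above replaces the old group_chunks/with_timeout pipeline with a two-step flow: search_chunks gathers matching chunk IDs from the embedding and BM25 backends, and search_sources hydrates them into SearchResult objects. A minimal usage sketch based only on the signatures in this diff; the query text, modality names, and filter values are illustrative and not part of the change:

    import asyncio

    from memory.common import extract
    from memory.api.search import search, SearchFilters


    async def demo() -> None:
        # extract_text() yields the DataChunk list that search() expects;
        # skip_summary=True mirrors the MCP change in this diff.
        data = extract.extract_text("qdrant payload filters", skip_summary=True)

        results = await search(
            data,
            previews=True,
            modalities={"doc", "mail"},  # hypothetical modality names
            limit=5,
            filters=SearchFilters(max_size=1_000_000),
            timeout=2,
        )
        for result in results:
            # chunks is now a list of elided strings rather than AnnotatedChunk objects
            print(result.id, result.mime_type, len(result.chunks))


    if __name__ == "__main__":
        asyncio.run(demo())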