@@ -115,7 +116,7 @@ export const Metadata = ({ metadata }: { metadata: any }) => {
return (
- {Object.entries(metadata).map(([key, value]) => (
+ {Object.entries(metadata).filter(([key, value]) => ![null, undefined].includes(value as any)).map(([key, value]) => (
))}
@@ -154,19 +155,19 @@ export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
}
export const SearchResult = ({ result }: { result: SearchItem }) => {
- if (result.mime_type.startsWith('image/')) {
+ if (result.mime_type?.startsWith('image/')) {
return
}
- if (result.mime_type.startsWith('text/markdown')) {
+ if (result.mime_type?.startsWith('text/markdown')) {
return
}
- if (result.mime_type.startsWith('text/')) {
+ if (result.mime_type?.startsWith('text/')) {
return
}
- if (result.mime_type.startsWith('application/pdf')) {
+ if (result.mime_type?.startsWith('application/pdf')) {
return
}
- if (result.mime_type.startsWith('message/rfc822')) {
+ if (result.mime_type?.startsWith('message/rfc822')) {
return
}
console.log(result)
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py
index ad84418..50837b6 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/memory.py
@@ -109,14 +109,12 @@ async def search_knowledge_base(
search_filters = SearchFilters(**filters)
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
- upload_data = extract.extract_text(query)
+ upload_data = extract.extract_text(query, skip_summary=True)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
- min_text_score=0.4,
- min_multimodal_score=0.25,
filters=search_filters,
)
diff --git a/src/memory/api/search/__init__.py b/src/memory/api/search/__init__.py
index f3b1cf1..932e3fe 100644
--- a/src/memory/api/search/__init__.py
+++ b/src/memory/api/search/__init__.py
@@ -1,4 +1,4 @@
from .search import search
-from .utils import SearchResult, SearchFilters
+from .types import SearchResult, SearchFilters
__all__ = ["search", "SearchResult", "SearchFilters"]
diff --git a/src/memory/api/search/bm25.py b/src/memory/api/search/bm25.py
index 91e07f5..e20780f 100644
--- a/src/memory/api/search/bm25.py
+++ b/src/memory/api/search/bm25.py
@@ -2,13 +2,15 @@
Search endpoints for the knowledge base API.
"""
+import asyncio
from hashlib import sha256
import logging
import bm25s
import Stemmer
-from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
+from memory.api.search.types import SearchFilters
+from memory.common import extract
from memory.common.db.connection import make_session
from memory.common.db.models import Chunk, ConfidenceScore
@@ -20,7 +22,7 @@ async def search_bm25(
modalities: set[str],
limit: int = 10,
filters: SearchFilters = SearchFilters(),
-) -> list[tuple[SourceData, AnnotatedChunk]]:
+) -> list[str]:
with make_session() as db:
items_query = db.query(Chunk.id, Chunk.content).filter(
Chunk.collection_name.in_(modalities),
@@ -65,21 +67,18 @@ async def search_bm25(
item_ids[sha256(doc.encode("utf-8")).hexdigest()]: score
for doc, score in zip(results[0], scores[0])
}
+ return list(item_scores.keys())
- with make_session() as db:
- chunks = db.query(Chunk).filter(Chunk.id.in_(item_scores.keys())).all()
- results = []
- for chunk in chunks:
- # Prefetch all needed source data while in session
- source_data = SourceData.from_chunk(chunk)
- annotated = AnnotatedChunk(
- id=str(chunk.id),
- score=item_scores[chunk.id],
- metadata=chunk.source.as_payload(),
- preview=None,
- search_method="bm25",
- )
- results.append((source_data, annotated))
-
- return results
+async def search_bm25_chunks(
+ data: list[extract.DataChunk],
+ modalities: set[str] = set(),
+ limit: int = 10,
+ filters: SearchFilters = SearchFilters(),
+ timeout: int = 2,
+) -> list[str]:
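+    """Join the string parts of the query chunks into one BM25 query and return matching chunk IDs, bounded by `timeout` seconds."""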
+ query = " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)])
+ return await asyncio.wait_for(
+ search_bm25(query, modalities, limit, filters),
+ timeout,
+ )
diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py
index 717c0b3..dd5721b 100644
--- a/src/memory/api/search/embeddings.py
+++ b/src/memory/api/search/embeddings.py
@@ -1,65 +1,20 @@
-import base64
-import io
import logging
import asyncio
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, cast
import qdrant_client
-from PIL import Image
from qdrant_client.http import models as qdrant_models
-from memory.common import embedding, extract, qdrant, settings
-from memory.common.db.connection import make_session
-from memory.common.db.models import Chunk
-from memory.api.search.utils import SourceData, AnnotatedChunk, SearchFilters
+from memory.common import embedding, extract, qdrant
+from memory.common.collections import (
+ MULTIMODAL_COLLECTIONS,
+ TEXT_COLLECTIONS,
+)
+from memory.api.search.types import SearchFilters
logger = logging.getLogger(__name__)
-def annotated_chunk(
- chunk: Chunk, search_result: qdrant_models.ScoredPoint, previews: bool
-) -> tuple[SourceData, AnnotatedChunk]:
- def serialize_item(item: bytes | str | Image.Image) -> str | None:
- if not previews and not isinstance(item, str):
- return None
- if (
- not previews
- and isinstance(item, str)
- and len(item) > settings.MAX_NON_PREVIEW_LENGTH
- ):
- return item[: settings.MAX_NON_PREVIEW_LENGTH] + "..."
- elif isinstance(item, str):
- if len(item) > settings.MAX_PREVIEW_LENGTH:
- return None
- return item
- if isinstance(item, Image.Image):
- buffer = io.BytesIO()
- format = item.format or "PNG"
- item.save(buffer, format=format)
- mime_type = f"image/{format.lower()}"
- return f"data:{mime_type};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
- elif isinstance(item, bytes):
- return base64.b64encode(item).decode("utf-8")
- else:
- raise ValueError(f"Unsupported item type: {type(item)}")
-
- metadata = search_result.payload or {}
- metadata = {
- k: v
- for k, v in metadata.items()
- if k not in ["content", "filename", "size", "content_type", "tags"]
- }
-
- # Prefetch all needed source data while in session
- return SourceData.from_chunk(chunk), AnnotatedChunk(
- id=str(chunk.id),
- score=search_result.score,
- metadata=metadata,
- preview=serialize_item(chunk.data[0]) if chunk.data else None,
- search_method="embeddings",
- )
-
-
async def query_chunks(
client: qdrant_client.QdrantClient,
upload_data: list[extract.DataChunk],
@@ -178,15 +133,14 @@ def merge_filters(
return filters
-async def search_embeddings(
+async def search_chunks(
data: list[extract.DataChunk],
- previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
filters: SearchFilters = {},
multimodal: bool = False,
-) -> list[tuple[SourceData, AnnotatedChunk]]:
+) -> list[str]:
"""
Search across the knowledge base using a text query and optional files.
@@ -218,9 +172,38 @@ async def search_embeddings(
found_chunks = {
str(r.id): r for results in search_results.values() for r in results
}
- with make_session() as db:
- chunks = db.query(Chunk).filter(Chunk.id.in_(found_chunks.keys())).all()
- return [
- annotated_chunk(chunk, found_chunks[str(chunk.id)], previews or False)
- for chunk in chunks
- ]
+ return list(found_chunks.keys())
+
+
+async def search_chunks_embeddings(
+ data: list[extract.DataChunk],
+ modalities: set[str] = set(),
+ limit: int = 10,
+ filters: SearchFilters = SearchFilters(),
+ timeout: int = 2,
+) -> list[str]:
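+    """Run the text and multimodal embedding searches concurrently and return the deduplicated chunk IDs."""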
+ all_ids = await asyncio.gather(
+ asyncio.wait_for(
+ search_chunks(
+ data,
+ modalities & TEXT_COLLECTIONS,
+ limit,
+ 0.4,
+ filters,
+ False,
+ ),
+ timeout,
+ ),
+ asyncio.wait_for(
+ search_chunks(
+ data,
+ modalities & MULTIMODAL_COLLECTIONS,
+ limit,
+ 0.25,
+ filters,
+ True,
+ ),
+ timeout,
+ ),
+ )
+ return list({id for ids in all_ids for id in ids})
diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py
index f876b0b..85b39c0 100644
--- a/src/memory/api/search/search.py
+++ b/src/memory/api/search/search.py
@@ -4,36 +4,70 @@ Search endpoints for the knowledge base API.
import asyncio
import logging
+from collections import defaultdict
from typing import Optional
-
+from sqlalchemy.orm import load_only
from memory.common import extract, settings
-from memory.common.collections import (
- ALL_COLLECTIONS,
- MULTIMODAL_COLLECTIONS,
- TEXT_COLLECTIONS,
-)
-from memory.api.search.embeddings import search_embeddings
+from memory.common.db.connection import make_session
+from memory.common.db.models import Chunk, SourceItem
+from memory.common.collections import ALL_COLLECTIONS
+from memory.api.search.embeddings import search_chunks_embeddings
if settings.ENABLE_BM25_SEARCH:
- from memory.api.search.bm25 import search_bm25
+ from memory.api.search.bm25 import search_bm25_chunks
-from memory.api.search.utils import (
- SearchFilters,
- SearchResult,
- group_chunks,
- with_timeout,
-)
+from memory.api.search.types import SearchFilters, SearchResult
logger = logging.getLogger(__name__)
+async def search_chunks(
+ data: list[extract.DataChunk],
+ modalities: set[str] = set(),
+ limit: int = 10,
+ filters: SearchFilters = {},
+ timeout: int = 2,
+) -> list[Chunk]:
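+    """Query the embedding search (and BM25 when enabled), merge the returned chunk IDs, and load the matching chunks detached from the session."""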
+ funcs = [search_chunks_embeddings]
+ if settings.ENABLE_BM25_SEARCH:
+ funcs.append(search_bm25_chunks)
+
+ all_ids = await asyncio.gather(
+ *[func(data, modalities, limit, filters, timeout) for func in funcs]
+ )
+ all_ids = {id for ids in all_ids for id in ids}
+
+ with make_session() as db:
+ chunks = (
+ db.query(Chunk)
+ .options(load_only(Chunk.id, Chunk.source_id, Chunk.content)) # type: ignore
+ .filter(Chunk.id.in_(all_ids))
+ .all()
+ )
+ db.expunge_all()
+ return chunks
+
+
+async def search_sources(
+ chunks: list[Chunk], previews: Optional[bool] = False
+) -> list[SearchResult]:
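+    """Group chunks by their source item and build one SearchResult per source."""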
+ by_source = defaultdict(list)
+ for chunk in chunks:
+ by_source[chunk.source_id].append(chunk)
+
+ with make_session() as db:
+ sources = db.query(SourceItem).filter(SourceItem.id.in_(by_source.keys())).all()
+ return [
+ SearchResult.from_source_item(source, by_source[source.id], previews)
+ for source in sources
+ ]
+
+
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
- min_text_score: float = 0.4,
- min_multimodal_score: float = 0.25,
filters: SearchFilters = {},
timeout: int = 2,
) -> list[SearchResult]:
@@ -50,56 +84,11 @@ async def search(
- List of search results sorted by score
"""
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
-
- searches = []
- if settings.ENABLE_EMBEDDING_SEARCH:
- searches = [
- with_timeout(
- search_embeddings(
- data,
- previews,
- allowed_modalities & TEXT_COLLECTIONS,
- limit,
- min_text_score,
- filters,
- multimodal=False,
- ),
- timeout,
- ),
- with_timeout(
- search_embeddings(
- data,
- previews,
- allowed_modalities & MULTIMODAL_COLLECTIONS,
- limit,
- min_multimodal_score,
- filters,
- multimodal=True,
- ),
- timeout,
- ),
- ]
- if settings.ENABLE_BM25_SEARCH:
- searches.append(
- with_timeout(
- search_bm25(
- " ".join(
- [c for chunk in data for c in chunk.data if isinstance(c, str)]
- ),
- modalities,
- limit=limit,
- filters=filters,
- ),
- timeout,
- )
- )
-
- search_results = await asyncio.gather(*searches, return_exceptions=False)
- all_results = []
- for results in search_results:
- if len(all_results) >= limit:
- break
- all_results.extend(results)
-
- results = group_chunks(all_results, previews or False)
- return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
+ chunks = await search_chunks(
+ data,
+ allowed_modalities,
+ limit,
+ filters,
+ timeout,
+ )
+ return await search_sources(chunks, previews)
diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py
new file mode 100644
index 0000000..def4a26
--- /dev/null
+++ b/src/memory/api/search/types.py
@@ -0,0 +1,67 @@
+from datetime import datetime
+import logging
+from typing import Optional, TypedDict, NotRequired, cast
+
+from memory.common.db.models.source_item import SourceItem
+from pydantic import BaseModel
+
+from memory.common.db.models import Chunk
+from memory.common import settings
+
+logger = logging.getLogger(__name__)
+
+
+class SearchResponse(BaseModel):
+ collection: str
+ results: list[dict]
+
+
+def elide_content(content: str, max_length: int = 100) -> str:
+ if content and len(content) > max_length:
+ return content[:max_length] + "..."
+ return content
+
+
+class SearchResult(BaseModel):
+ id: int
+ chunks: list[str]
+ size: int | None = None
+ mime_type: str | None = None
+ content: Optional[str | dict] = None
+ filename: Optional[str] = None
+ tags: list[str] | None = None
+ metadata: dict | None = None
+ created_at: datetime | None = None
+
+ @classmethod
+ def from_source_item(
+ cls, source: SourceItem, chunks: list[Chunk], previews: Optional[bool] = False
+ ) -> "SearchResult":
+ metadata = source.display_contents or {}
+ metadata.pop("content", None)
+ chunk_size = settings.DEFAULT_CHUNK_TOKENS * 4
+
+ return cls(
+ id=cast(int, source.id),
+ size=cast(int, source.size),
+ mime_type=cast(str, source.mime_type),
+ chunks=[elide_content(str(chunk.content), chunk_size) for chunk in chunks],
+ content=elide_content(
+ cast(str, source.content),
+ settings.MAX_PREVIEW_LENGTH
+ if previews
+ else settings.MAX_NON_PREVIEW_LENGTH,
+ ),
+ filename=cast(str, source.filename),
+ tags=cast(list[str], source.tags),
+ metadata=metadata,
+ created_at=cast(datetime | None, source.inserted_at),
+ )
+
+
+class SearchFilters(TypedDict):
+ min_size: NotRequired[int]
+ max_size: NotRequired[int]
+ min_confidences: NotRequired[dict[str, float]]
+ observation_types: NotRequired[list[str] | None]
+ source_ids: NotRequired[list[int] | None]
diff --git a/src/memory/api/search/utils.py b/src/memory/api/search/utils.py
deleted file mode 100644
index 98ae1c1..0000000
--- a/src/memory/api/search/utils.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import asyncio
-import traceback
-from datetime import datetime
-import logging
-from collections import defaultdict
-from typing import Optional, TypedDict, NotRequired
-
-from pydantic import BaseModel
-
-from memory.common import settings
-from memory.common.db.models import Chunk
-
-logger = logging.getLogger(__name__)
-
-
-class AnnotatedChunk(BaseModel):
- id: str
- score: float
- metadata: dict
- preview: Optional[str | None] = None
- search_method: str | None = None
-
-
-class SourceData(BaseModel):
- """Holds source item data to avoid SQLAlchemy session issues"""
-
- id: int
- size: int | None
- mime_type: str | None
- filename: str | None
- content_length: int
- contents: dict | str | None
- created_at: datetime | None
-
- @staticmethod
- def from_chunk(chunk: Chunk) -> "SourceData":
- source = chunk.source
- display_contents = source.display_contents or {}
- return SourceData(
- id=source.id,
- size=source.size,
- mime_type=source.mime_type,
- filename=source.filename,
- content_length=len(source.content) if source.content else 0,
- contents={k: v for k, v in display_contents.items() if v is not None},
- created_at=source.inserted_at,
- )
-
-
-class SearchResponse(BaseModel):
- collection: str
- results: list[dict]
-
-
-class SearchResult(BaseModel):
- id: int
- size: int
- mime_type: str
- chunks: list[AnnotatedChunk]
- content: Optional[str | dict] = None
- filename: Optional[str] = None
- tags: list[str] | None = None
- metadata: dict | None = None
- created_at: datetime | None = None
-
-
-class SearchFilters(TypedDict):
- min_size: NotRequired[int]
- max_size: NotRequired[int]
- min_confidences: NotRequired[dict[str, float]]
- observation_types: NotRequired[list[str] | None]
- source_ids: NotRequired[list[int] | None]
-
-
-async def with_timeout(
- call, timeout: int = 2
-) -> list[tuple[SourceData, AnnotatedChunk]]:
- """
- Run a function with a timeout.
-
- Args:
- call: The function to run
- timeout: The timeout in seconds
- """
- try:
- return await asyncio.wait_for(call, timeout=timeout)
- except TimeoutError:
- logger.warning(f"Search timed out after {timeout}s")
- return []
- except Exception as e:
- traceback.print_exc()
- logger.error(f"Search failed: {e}")
- return []
-
-
-def group_chunks(
- chunks: list[tuple[SourceData, AnnotatedChunk]], preview: bool = False
-) -> list[SearchResult]:
- items = defaultdict(list)
- source_lookup = {}
-
- for source, chunk in chunks:
- items[source.id].append(chunk)
- source_lookup[source.id] = source
-
- def get_content(text: str | dict | None) -> str | dict | None:
- if isinstance(text, str) and len(text) > settings.MAX_PREVIEW_LENGTH:
- return None
- return text
-
- def make_result(source: SourceData, chunks: list[AnnotatedChunk]) -> SearchResult:
- contents = source.contents or {}
- tags = []
- if isinstance(contents, dict):
- tags = contents.pop("tags", [])
- content = contents.pop("content", None)
- else:
- content = contents
- contents = {}
-
- return SearchResult(
- id=source.id,
- size=source.size or source.content_length,
- mime_type=source.mime_type or "text/plain",
- filename=source.filename
- and source.filename.replace(
- str(settings.FILE_STORAGE_DIR).lstrip("/"), "/files"
- ),
- content=get_content(content),
- tags=tags,
- metadata=contents,
- chunks=sorted(chunks, key=lambda x: x.score, reverse=True),
- created_at=source.created_at,
- )
-
- return [
- make_result(source, chunks)
- for source_id, chunks in items.items()
- for source in [source_lookup[source_id]]
- ]
diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py
index b7ed5ae..0035921 100644
--- a/src/memory/common/db/models/source_item.py
+++ b/src/memory/common/db/models/source_item.py
@@ -368,7 +368,7 @@ class SourceItem(Base):
return [cls.__tablename__]
@property
- def display_contents(self) -> str | dict | None:
+ def display_contents(self) -> dict | None:
payload = self.as_payload()
payload.pop("source_id", None) # type: ignore
return {
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 4033351..2943506 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -135,7 +135,7 @@ SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240
# Search settings
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
-MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 8))
+MAX_PREVIEW_LENGTH = int(os.getenv("MAX_PREVIEW_LENGTH", DEFAULT_CHUNK_TOKENS * 16))
MAX_NON_PREVIEW_LENGTH = int(os.getenv("MAX_NON_PREVIEW_LENGTH", 2000))
# API settings