diff --git a/frontend/src/components/search/Search.tsx b/frontend/src/components/search/Search.tsx
index 02302ca..dc1218e 100644
--- a/frontend/src/components/search/Search.tsx
+++ b/frontend/src/components/search/Search.tsx
@@ -40,8 +40,7 @@ const Search = () => {
     setIsLoading(true)
     try {
-      console.log(params)
-      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
+      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
       setResults(searchResults || [])
     } catch (error) {
       console.error('Search error:', error)
diff --git a/frontend/src/components/search/SearchForm.tsx b/frontend/src/components/search/SearchForm.tsx
index 717fa78..5b0a5c5 100644
--- a/frontend/src/components/search/SearchForm.tsx
+++ b/frontend/src/components/search/SearchForm.tsx
@@ -10,12 +10,17 @@ type Filter = {
   [key: string]: any
 }
 
+type SearchConfig = {
+  previews: boolean
+  useScores: boolean
+  limit: number
+}
+
 export interface SearchParams {
   query: string
-  previews: boolean
   modalities: string[]
   filters: Filter
-  limit: number
+  config: SearchConfig
 }
 
 interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
 export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
   const [query, setQuery] = useState('')
   const [previews, setPreviews] = useState(false)
+  const [useScores, setUseScores] = useState(false)
   const [modalities, setModalities] = useState<Record<string, boolean>>({})
   const [schemas, setSchemas] = useState<Record<string, any>>({})
   const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
 
     onSearch({
       query,
-      previews,
       modalities: getSelectedItems(modalities),
+      config: {
+        previews,
+        useScores,
+        limit
+      },
       filters: {
         tags: getSelectedItems(tags),
         ...cleanFilters(dynamicFilters)
       },
-      limit
     })
   }
 
@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
           />
           Include content previews
         </label>
+        <label>
+          <input
+            type="checkbox"
+            checked={useScores}
+            onChange={(e) => setUseScores(e.target.checked)}
+          />
+          Score results with an LLM
+        </label>
+
       </div>
diff --git a/frontend/... b/frontend/...
@@ ... @@
     return (await mcpCall('get_metadata_schemas'))[0]
   }, [mcpCall])
 
-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
     return await mcpCall('search_knowledge_base', {
       query,
       filters,
+      config,
       modalities,
-      previews,
-      limit,
     })
   }, [mcpCall])
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py
index 05dd1dd..0eff43b 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/memory.py
@@ -6,17 +6,16 @@
 import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from typing import Any
 
 from PIL import Image
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
-from mcp.server.fastmcp.resources.base import Resource
 from memory.api.MCP.tools import mcp
-from memory.api.search.search import SearchFilters, search
+from memory.api.search.search import search
+from memory.api.search.types import SearchFilters, SearchConfig
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    filters: dict[str, Any],
+    filters: SearchFilters,
+    config: SearchConfig = SearchConfig(),
     modalities: set[str] = set(),
-    previews: bool = False,
-    limit: int = 10,
 ) -> list[dict]:
     """
     Search user's stored content including emails, documents, articles, books.
     Use to find specific information the user has saved or received. Combine
     with search_observations for complete user context.
 
+    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
+
+    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
 
     Args:
         query: Natural language search query - be descriptive about what you're looking for
-        previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        filters: Filter by tags, source_ids, etc.
-        limit: Max results (1-100)
+        filters: a dictionary with the following keys:
+            - tags: a list of tags to filter by
+            - source_ids: a list of source ids to filter by
+            - min_size: the minimum size of the content to filter by
+            - max_size: the maximum size of the content to filter by
+            - min_created_at: the minimum created_at date to filter by
+            - max_created_at: the maximum created_at date to filter by
+        config: a dictionary with the following keys:
+            - limit: the maximum number of results to return
+            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
+            - useScores: whether to score the results with an LLM before returning - this gives better results but is slower
 
     Returns: List of search results with id, score, chunks, content, filename
     Higher scores (>0.7) indicate strong matches.
@@ -112,10 +121,9 @@ async def search_knowledge_base(
     upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
-        previews=previews,
         modalities=modalities,
-        limit=limit,
         filters=search_filters,
+        config=config,
     )
 
     return [result.model_dump() for result in results]
@@ -211,7 +219,7 @@ async def search_observations(
     tags: list[str] | None = None,
     observation_types: list[str] | None = None,
     min_confidences: dict[str, float] = {},
-    limit: int = 10,
+    config: SearchConfig = SearchConfig(),
 ) -> list[dict]:
     """
     Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
         tags: Filter by tags (must have at least one matching tag)
         observation_types: Filter by: belief, preference, behavior, contradiction, general
         min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
-        limit: Max results (1-100)
+        config: search configuration (same keys as in search_knowledge_base: limit, previews, useScores)
 
     Returns: List with content, tags, created_at, metadata
     Results sorted by relevance to your query.
@@ -246,9 +254,7 @@ async def search_observations(
             extract.DataChunk(data=[semantic_text]),
             extract.DataChunk(data=[temporal]),
         ],
-        previews=True,
         modalities={"semantic", "temporal"},
-        limit=limit,
         filters=SearchFilters(
             subject=subject,
             min_confidences=min_confidences,
@@ -256,7 +262,7 @@ async def search_observations(
             observation_types=observation_types,
             source_ids=filter_observation_source_ids(tags=tags),
         ),
-        timeout=2,
+        config=config,
     )
 
     return [
diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py
index 3316659..387a580 100644
--- a/src/memory/api/search/search.py
+++ b/src/memory/api/search/search.py
@@ -17,7 +17,7 @@
 from memory.api.search import scorer
 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
 
-from memory.api.search.types import SearchFilters, SearchResult
+from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
 
 logger = logging.getLogger(__name__)
@@ -57,7 +57,7 @@ async def search_chunks(
 
 
 async def search_sources(
-    chunks: list[Chunk], previews: Optional[bool] = False
+    chunks: list[Chunk], previews: bool = False
 ) -> list[SearchResult]:
     by_source = defaultdict(list)
     for chunk in chunks:
@@ -73,11 +73,9 @@ async def search_sources(
 async def search(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
-    limit: int = 10,
     filters: SearchFilters = {},
-    timeout: int = 20,
+    config: SearchConfig = SearchConfig(),
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -95,13 +93,13 @@ async def search(
     chunks = await search_chunks(
         data,
         allowed_modalities,
-        limit,
+        config.limit,
         filters,
-        timeout,
+        config.timeout,
     )
-    if settings.ENABLE_SEARCH_SCORING:
+    if settings.ENABLE_SEARCH_SCORING and config.useScores:
         chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
 
-    sources = await search_sources(chunks, previews)
+    sources = await search_sources(chunks, config.previews)
     sources.sort(key=lambda x: x.search_score or 0, reverse=True)
-    return sources[:limit]
+    return sources[: config.limit]
diff --git a/src/memory/api/search/types.py b/src/memory/api/search/types.py
index b6ddd74..ab13734 100644
--- a/src/memory/api/search/types.py
+++ b/src/memory/api/search/types.py
@@ -2,10 +2,9 @@
 from datetime import datetime
 import logging
 from typing import Optional, TypedDict, NotRequired, cast
-from memory.common.db.models.source_item import SourceItem
 from pydantic import BaseModel
 
-from memory.common.db.models import Chunk
+from memory.common.db.models import Chunk, SourceItem
 from memory.common import settings
 
 logger = logging.getLogger(__name__)
@@ -69,3 +68,10 @@ class SearchFilters(TypedDict):
     min_confidences: NotRequired[dict[str, float]]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
+
+
+class SearchConfig(BaseModel):
+    limit: int = 10
+    timeout: int = 20
+    previews: bool = False
+    useScores: bool = False
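
Appendix (not part of the patch): a minimal sketch of the new call shape, assuming the `@mcp.tool()`-decorated function can still be awaited directly in-process and that `tags` is one of the `SearchFilters` keys described in the docstring above; the query string and tag are illustrative. `SearchConfig` defaults (limit=10, timeout=20, previews=False, useScores=False) come from src/memory/api/search/types.py.

    import asyncio

    from memory.api.MCP.memory import search_knowledge_base
    from memory.api.search.types import SearchConfig, SearchFilters


    async def main() -> None:
        # useScores=True gates the scorer.rank_chunks() LLM pass in search.py;
        # previews=True makes search_sources() return content instead of snippets.
        results = await search_knowledge_base(
            query="notes on database migrations",
            filters=SearchFilters(tags=["postgres"]),  # hypothetical tag
            config=SearchConfig(limit=5, previews=True, useScores=True),
            modalities={"blog", "book"},
        )
        for result in results:
            # each result dict carries id, score, chunks, content, filename
            print(result["id"], result["score"])


    asyncio.run(main())

One behavior change worth flagging in review: `timeout` is now only settable through `SearchConfig`. `search_observations` previously passed `timeout=2` but now inherits the default of 20 seconds, so observation searches may wait considerably longer before giving up.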