more improvements

Daniel O'Connell 2025-06-28 22:14:14 +02:00
parent 1276b83ffb
commit 387bd962e6
6 changed files with 64 additions and 37 deletions

View File

@@ -40,8 +40,7 @@ const Search = () => {
     setIsLoading(true)
     try {
-      console.log(params)
-      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
+      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
       setResults(searchResults || [])
     } catch (error) {
       console.error('Search error:', error)

View File

@@ -10,12 +10,17 @@ type Filter = {
   [key: string]: any
 }
+type SearchConfig = {
+  previews: boolean
+  useScores: boolean
+  limit: number
+}
+
 export interface SearchParams {
   query: string
-  previews: boolean
   modalities: string[]
   filters: Filter
-  limit: number
+  config: SearchConfig
 }
 interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
 export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
   const [query, setQuery] = useState('')
   const [previews, setPreviews] = useState(false)
+  const [useScores, setUseScores] = useState(false)
   const [modalities, setModalities] = useState<Record<string, boolean>>({})
   const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
   const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
     onSearch({
       query,
-      previews,
       modalities: getSelectedItems(modalities),
+      config: {
+        previews,
+        useScores,
+        limit
+      },
       filters: {
         tags: getSelectedItems(tags),
         ...cleanFilters(dynamicFilters)
       },
-      limit
     })
   }
@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
           Include content previews
         </label>
       </div>
+      <div className="search-option">
+        <label>
+          <input
+            type="checkbox"
+            checked={useScores}
+            onChange={(e) => setUseScores(e.target.checked)}
+          />
+          Score results with an LLM before returning
+        </label>
+      </div>
       <SelectableTags
         title="Modalities"

View File

@@ -151,13 +151,12 @@ export const useMCP = () => {
     return (await mcpCall('get_metadata_schemas'))[0]
   }, [mcpCall])
-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
     return await mcpCall('search_knowledge_base', {
       query,
       filters,
+      config,
       modalities,
-      previews,
-      limit,
     })
   }, [mcpCall])

View File

@@ -6,17 +6,16 @@ import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from typing import Any
 from PIL import Image
 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
-from mcp.server.fastmcp.resources.base import Resource
 from memory.api.MCP.tools import mcp
-from memory.api.search.search import SearchFilters, search
+from memory.api.search.search import search
+from memory.api.search.types import SearchFilters, SearchConfig
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    filters: dict[str, Any],
+    filters: SearchFilters,
+    config: SearchConfig = SearchConfig(),
     modalities: set[str] = set(),
-    previews: bool = False,
-    limit: int = 10,
 ) -> list[dict]:
     """
     Search user's stored content including emails, documents, articles, books.
     Use to find specific information the user has saved or received.
     Combine with search_observations for complete user context.
+    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
+    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
     Args:
         query: Natural language search query - be descriptive about what you're looking for
-        previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        filters: Filter by tags, source_ids, etc.
-        limit: Max results (1-100)
+        filters: a dictionary with the following keys:
+            - tags: a list of tags to filter by
+            - source_ids: a list of source ids to filter by
+            - min_size: the minimum size of the content to filter by
+            - max_size: the maximum size of the content to filter by
+            - min_created_at: the minimum created_at date to filter by
+            - max_created_at: the maximum created_at date to filter by
+        config: a dictionary with the following keys:
+            - limit: the maximum number of results to return
+            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
+            - useScores: whether to score the results with an LLM before returning - this gives better results but is slower
     Returns: List of search results with id, score, chunks, content, filename
              Higher scores (>0.7) indicate strong matches.
@@ -112,10 +121,9 @@ async def search_knowledge_base(
     upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
-        previews=previews,
         modalities=modalities,
-        limit=limit,
         filters=search_filters,
+        config=config,
     )
     return [result.model_dump() for result in results]
@@ -211,7 +219,7 @@ async def search_observations(
     tags: list[str] | None = None,
     observation_types: list[str] | None = None,
     min_confidences: dict[str, float] = {},
-    limit: int = 10,
+    config: SearchConfig = SearchConfig(),
 ) -> list[dict]:
     """
     Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
         tags: Filter by tags (must have at least one matching tag)
         observation_types: Filter by: belief, preference, behavior, contradiction, general
         min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
-        limit: Max results (1-100)
+        config: SearchConfig
     Returns: List with content, tags, created_at, metadata
              Results sorted by relevance to your query.
@@ -246,9 +254,7 @@
             extract.DataChunk(data=[semantic_text]),
             extract.DataChunk(data=[temporal]),
         ],
-        previews=True,
         modalities={"semantic", "temporal"},
-        limit=limit,
         filters=SearchFilters(
             subject=subject,
             min_confidences=min_confidences,
@@ -256,7 +262,7 @@
             observation_types=observation_types,
             source_ids=filter_observation_source_ids(tags=tags),
         ),
-        timeout=2,
+        config=config,
     )
     return [
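For reference, the frontend's onSearch payload above and this reworked tool signature now agree on one shape: previews, useScores and limit travel inside config, and tag/source filters inside filters. A hypothetical request payload, written as a Python dict with purely illustrative values (the key names come from the docstring in the diff):

arguments = {
    "query": "notes about solar panel efficiency",  # illustrative query text
    "modalities": ["email", "blog"],                # empty = search all modalities
    "filters": {"tags": ["energy"]},                # other documented keys: source_ids, min_size, max_size, min_created_at, max_created_at
    "config": {"limit": 5, "previews": True, "useScores": True},
}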

View File

@@ -17,7 +17,7 @@ from memory.api.search import scorer
 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks
-from memory.api.search.types import SearchFilters, SearchResult
+from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
 logger = logging.getLogger(__name__)
@@ -57,7 +57,7 @@ async def search_chunks(
 async def search_sources(
-    chunks: list[Chunk], previews: Optional[bool] = False
+    chunks: list[Chunk], previews: bool = False
 ) -> list[SearchResult]:
     by_source = defaultdict(list)
     for chunk in chunks:
@@ -73,11 +73,9 @@ async def search_sources(
 async def search(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
-    limit: int = 10,
     filters: SearchFilters = {},
-    timeout: int = 20,
+    config: SearchConfig = SearchConfig(),
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -95,13 +93,13 @@ async def search(
     chunks = await search_chunks(
         data,
         allowed_modalities,
-        limit,
+        config.limit,
         filters,
-        timeout,
+        config.timeout,
     )
-    if settings.ENABLE_SEARCH_SCORING:
+    if settings.ENABLE_SEARCH_SCORING and config.useScores:
         chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
-    sources = await search_sources(chunks, previews)
+    sources = await search_sources(chunks, config.previews)
     sources.sort(key=lambda x: x.search_score or 0, reverse=True)
-    return sources[:limit]
+    return sources[: config.limit]

View File

@@ -2,10 +2,9 @@ from datetime import datetime
 import logging
 from typing import Optional, TypedDict, NotRequired, cast
-from memory.common.db.models.source_item import SourceItem
 from pydantic import BaseModel
-from memory.common.db.models import Chunk
+from memory.common.db.models import Chunk, SourceItem
 from memory.common import settings
 logger = logging.getLogger(__name__)
@@ -69,3 +68,10 @@ class SearchFilters(TypedDict):
     min_confidences: NotRequired[dict[str, float]]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
+
+
+class SearchConfig(BaseModel):
+    limit: int = 10
+    timeout: int = 20
+    previews: bool = False
+    useScores: bool = False
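
A minimal usage sketch, not part of the commit, showing how the new SearchConfig flows into search(). Module paths follow the imports in the diff; the query string and modality are made up:

import asyncio

from memory.api.search.search import search
from memory.api.search.types import SearchConfig
from memory.common import extract


async def demo():
    # extract_text() is used the same way by search_knowledge_base above
    data = extract.extract_text("solar panel efficiency", skip_summary=True)
    # previews/useScores/limit/timeout replace the old loose keyword arguments
    config = SearchConfig(limit=5, timeout=10, previews=True, useScores=True)
    # the LLM scorer only runs when ENABLE_SEARCH_SCORING and config.useScores are both set
    return await search(data, modalities={"blog"}, filters={}, config=config)


results = asyncio.run(demo())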