more improvements

This commit is contained in:
Daniel O'Connell 2025-06-28 22:14:14 +02:00
parent 1276b83ffb
commit 387bd962e6
6 changed files with 64 additions and 37 deletions

View File

@ -40,8 +40,7 @@ const Search = () => {
setIsLoading(true)
try {
console.log(params)
const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
setResults(searchResults || [])
} catch (error) {
console.error('Search error:', error)

View File

@ -10,12 +10,17 @@ type Filter = {
[key: string]: any
}
/**
 * Options controlling how a knowledge-base search is executed,
 * sent to the backend as the `config` field of SearchParams.
 * NOTE(review): the backend pydantic SearchConfig also has a `timeout`
 * field (default 20) that is not settable from here — confirm that
 * relying on the server default is intended.
 */
type SearchConfig = {
  previews: boolean // include actual content in each result, not just a snippet
  useScores: boolean // re-rank results with an LLM before returning (better but slower)
  limit: number // maximum number of results to return
}
export interface SearchParams {
query: string
previews: boolean
modalities: string[]
filters: Filter
limit: number
config: SearchConfig
}
interface SearchFormProps {
@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
const [query, setQuery] = useState('')
const [previews, setPreviews] = useState(false)
const [useScores, setUseScores] = useState(false)
const [modalities, setModalities] = useState<Record<string, boolean>>({})
const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
const [tags, setTags] = useState<Record<string, boolean>>({})
@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
onSearch({
query,
previews,
modalities: getSelectedItems(modalities),
config: {
previews,
useScores,
limit
},
filters: {
tags: getSelectedItems(tags),
...cleanFilters(dynamicFilters)
},
limit
})
}
@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
Include content previews
</label>
</div>
<div className="search-option">
<label>
<input
type="checkbox"
checked={useScores}
onChange={(e) => setUseScores(e.target.checked)}
/>
Score results with an LLM before returning
</label>
</div>
<SelectableTags
title="Modalities"

View File

@ -151,13 +151,12 @@ export const useMCP = () => {
return (await mcpCall('get_metadata_schemas'))[0]
}, [mcpCall])
const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
return await mcpCall('search_knowledge_base', {
query,
filters,
config,
modalities,
previews,
limit,
})
}, [mcpCall])

View File

@ -6,17 +6,16 @@ import base64
import logging
import pathlib
from datetime import datetime, timezone
from typing import Any
from PIL import Image
from pydantic import BaseModel
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
from mcp.server.fastmcp.resources.base import Resource
from memory.api.MCP.tools import mcp
from memory.api.search.search import SearchFilters, search
from memory.api.search.search import search
from memory.api.search.types import SearchFilters, SearchConfig
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
@mcp.tool()
async def search_knowledge_base(
query: str,
filters: dict[str, Any],
filters: SearchFilters,
config: SearchConfig = SearchConfig(),
modalities: set[str] = set(),
previews: bool = False,
limit: int = 10,
) -> list[dict]:
"""
Search user's stored content including emails, documents, articles, books.
Use to find specific information the user has saved or received.
Combine with search_observations for complete user context.
Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
Args:
query: Natural language search query - be descriptive about what you're looking for
previews: Include actual content in results - when false only a snippet is returned
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
filters: Filter by tags, source_ids, etc.
limit: Max results (1-100)
filters: a dictionary with the following keys:
- tags: a list of tags to filter by
- source_ids: a list of source ids to filter by
- min_size: the minimum size of the content to filter by
- max_size: the maximum size of the content to filter by
- min_created_at: the minimum created_at date to filter by
- max_created_at: the maximum created_at date to filter by
config: a dictionary with the following keys:
- limit: the maximum number of results to return
- previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
- useScores: whether to score the results with an LLM before returning - this results in better results but is slower
Returns: List of search results with id, score, chunks, content, filename
Higher scores (>0.7) indicate strong matches.
@ -112,10 +121,9 @@ async def search_knowledge_base(
upload_data = extract.extract_text(query, skip_summary=True)
results = await search(
upload_data,
previews=previews,
modalities=modalities,
limit=limit,
filters=search_filters,
config=config,
)
return [result.model_dump() for result in results]
@ -211,7 +219,7 @@ async def search_observations(
tags: list[str] | None = None,
observation_types: list[str] | None = None,
min_confidences: dict[str, float] = {},
limit: int = 10,
config: SearchConfig = SearchConfig(),
) -> list[dict]:
"""
Search recorded observations about the user.
@ -224,7 +232,7 @@ async def search_observations(
tags: Filter by tags (must have at least one matching tag)
observation_types: Filter by: belief, preference, behavior, contradiction, general
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
limit: Max results (1-100)
config: SearchConfig
Returns: List with content, tags, created_at, metadata
Results sorted by relevance to your query.
@ -246,9 +254,7 @@ async def search_observations(
extract.DataChunk(data=[semantic_text]),
extract.DataChunk(data=[temporal]),
],
previews=True,
modalities={"semantic", "temporal"},
limit=limit,
filters=SearchFilters(
subject=subject,
min_confidences=min_confidences,
@ -256,7 +262,7 @@ async def search_observations(
observation_types=observation_types,
source_ids=filter_observation_source_ids(tags=tags),
),
timeout=2,
config=config,
)
return [

View File

@ -17,7 +17,7 @@ from memory.api.search import scorer
if settings.ENABLE_BM25_SEARCH:
from memory.api.search.bm25 import search_bm25_chunks
from memory.api.search.types import SearchFilters, SearchResult
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
logger = logging.getLogger(__name__)
@ -57,7 +57,7 @@ async def search_chunks(
async def search_sources(
chunks: list[Chunk], previews: Optional[bool] = False
chunks: list[Chunk], previews: bool = False
) -> list[SearchResult]:
by_source = defaultdict(list)
for chunk in chunks:
@ -73,11 +73,9 @@ async def search_sources(
async def search(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
filters: SearchFilters = {},
timeout: int = 20,
config: SearchConfig = SearchConfig(),
) -> list[SearchResult]:
"""
Search across knowledge base using text query and optional files.
@ -95,13 +93,13 @@ async def search(
chunks = await search_chunks(
data,
allowed_modalities,
limit,
config.limit,
filters,
timeout,
config.timeout,
)
if settings.ENABLE_SEARCH_SCORING:
if settings.ENABLE_SEARCH_SCORING and config.useScores:
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
sources = await search_sources(chunks, previews)
sources = await search_sources(chunks, config.previews)
sources.sort(key=lambda x: x.search_score or 0, reverse=True)
return sources[:limit]
return sources[: config.limit]

View File

@ -2,10 +2,9 @@ from datetime import datetime
import logging
from typing import Optional, TypedDict, NotRequired, cast
from memory.common.db.models.source_item import SourceItem
from pydantic import BaseModel
from memory.common.db.models import Chunk
from memory.common.db.models import Chunk, SourceItem
from memory.common import settings
logger = logging.getLogger(__name__)
@ -69,3 +68,10 @@ class SearchFilters(TypedDict):
min_confidences: NotRequired[dict[str, float]]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
class SearchConfig(BaseModel):
    """Options controlling how a search is executed.

    Passed to `search()` and consumed there: `limit` and `timeout` bound the
    chunk search, `useScores` gates LLM re-ranking, and `previews` controls
    whether full content is included when building results.
    """

    # Maximum number of results to return (also caps the chunk search).
    limit: int = 10
    # Timeout (seconds) passed to the underlying chunk search.
    timeout: int = 20
    # Include actual content in results instead of only a snippet.
    previews: bool = False
    # Re-rank results with an LLM before returning — better results, but slower.
    # NOTE: camelCase name (not snake_case) matches the frontend field name.
    useScores: bool = False