mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 23:24:43 +02:00
more improvements
This commit is contained in:
parent
1276b83ffb
commit
387bd962e6
@ -40,8 +40,7 @@ const Search = () => {
|
||||
|
||||
setIsLoading(true)
|
||||
try {
|
||||
console.log(params)
|
||||
const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
|
||||
const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
|
||||
setResults(searchResults || [])
|
||||
} catch (error) {
|
||||
console.error('Search error:', error)
|
||||
|
@ -10,12 +10,17 @@ type Filter = {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
type SearchConfig = {
|
||||
previews: boolean
|
||||
useScores: boolean
|
||||
limit: number
|
||||
}
|
||||
|
||||
export interface SearchParams {
|
||||
query: string
|
||||
previews: boolean
|
||||
modalities: string[]
|
||||
filters: Filter
|
||||
limit: number
|
||||
config: SearchConfig
|
||||
}
|
||||
|
||||
interface SearchFormProps {
|
||||
@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
|
||||
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
|
||||
const [query, setQuery] = useState('')
|
||||
const [previews, setPreviews] = useState(false)
|
||||
const [useScores, setUseScores] = useState(false)
|
||||
const [modalities, setModalities] = useState<Record<string, boolean>>({})
|
||||
const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
|
||||
const [tags, setTags] = useState<Record<string, boolean>>({})
|
||||
@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
|
||||
|
||||
onSearch({
|
||||
query,
|
||||
previews,
|
||||
modalities: getSelectedItems(modalities),
|
||||
config: {
|
||||
previews,
|
||||
useScores,
|
||||
limit
|
||||
},
|
||||
filters: {
|
||||
tags: getSelectedItems(tags),
|
||||
...cleanFilters(dynamicFilters)
|
||||
},
|
||||
limit
|
||||
})
|
||||
}
|
||||
|
||||
@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
|
||||
Include content previews
|
||||
</label>
|
||||
</div>
|
||||
<div className="search-option">
|
||||
<label>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={useScores}
|
||||
onChange={(e) => setUseScores(e.target.checked)}
|
||||
/>
|
||||
Score results with a LLM before returning
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<SelectableTags
|
||||
title="Modalities"
|
||||
|
@ -151,13 +151,12 @@ export const useMCP = () => {
|
||||
return (await mcpCall('get_metadata_schemas'))[0]
|
||||
}, [mcpCall])
|
||||
|
||||
const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
|
||||
const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
|
||||
return await mcpCall('search_knowledge_base', {
|
||||
query,
|
||||
filters,
|
||||
config,
|
||||
modalities,
|
||||
previews,
|
||||
limit,
|
||||
})
|
||||
}, [mcpCall])
|
||||
|
||||
|
@ -6,17 +6,16 @@ import base64
|
||||
import logging
|
||||
import pathlib
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import cast as sql_cast
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from mcp.server.fastmcp.resources.base import Resource
|
||||
|
||||
from memory.api.MCP.tools import mcp
|
||||
from memory.api.search.search import SearchFilters, search
|
||||
from memory.api.search.search import search
|
||||
from memory.api.search.types import SearchFilters, SearchConfig
|
||||
from memory.common import extract, settings
|
||||
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
|
||||
from memory.common.celery_app import app as celery_app
|
||||
@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
|
||||
@mcp.tool()
|
||||
async def search_knowledge_base(
|
||||
query: str,
|
||||
filters: dict[str, Any],
|
||||
filters: SearchFilters,
|
||||
config: SearchConfig = SearchConfig(),
|
||||
modalities: set[str] = set(),
|
||||
previews: bool = False,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search user's stored content including emails, documents, articles, books.
|
||||
Use to find specific information the user has saved or received.
|
||||
Combine with search_observations for complete user context.
|
||||
Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
|
||||
|
||||
If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
|
||||
|
||||
Args:
|
||||
query: Natural language search query - be descriptive about what you're looking for
|
||||
previews: Include actual content in results - when false only a snippet is returned
|
||||
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
||||
filters: Filter by tags, source_ids, etc.
|
||||
limit: Max results (1-100)
|
||||
filters: a dictionary with the following keys:
|
||||
- tags: a list of tags to filter by
|
||||
- source_ids: a list of source ids to filter by
|
||||
- min_size: the minimum size of the content to filter by
|
||||
- max_size: the maximum size of the content to filter by
|
||||
- min_created_at: the minimum created_at date to filter by
|
||||
- max_created_at: the maximum created_at date to filter by
|
||||
config: a dictionary with the following keys:
|
||||
- limit: the maximum number of results to return
|
||||
- previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
|
||||
- useScores: whether to score the results with a LLM before returning - this results in better results but is slower
|
||||
|
||||
Returns: List of search results with id, score, chunks, content, filename
|
||||
Higher scores (>0.7) indicate strong matches.
|
||||
@ -112,10 +121,9 @@ async def search_knowledge_base(
|
||||
upload_data = extract.extract_text(query, skip_summary=True)
|
||||
results = await search(
|
||||
upload_data,
|
||||
previews=previews,
|
||||
modalities=modalities,
|
||||
limit=limit,
|
||||
filters=search_filters,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return [result.model_dump() for result in results]
|
||||
@ -211,7 +219,7 @@ async def search_observations(
|
||||
tags: list[str] | None = None,
|
||||
observation_types: list[str] | None = None,
|
||||
min_confidences: dict[str, float] = {},
|
||||
limit: int = 10,
|
||||
config: SearchConfig = SearchConfig(),
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search recorded observations about the user.
|
||||
@ -224,7 +232,7 @@ async def search_observations(
|
||||
tags: Filter by tags (must have at least one matching tag)
|
||||
observation_types: Filter by: belief, preference, behavior, contradiction, general
|
||||
min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
|
||||
limit: Max results (1-100)
|
||||
config: SearchConfig
|
||||
|
||||
Returns: List with content, tags, created_at, metadata
|
||||
Results sorted by relevance to your query.
|
||||
@ -246,9 +254,7 @@ async def search_observations(
|
||||
extract.DataChunk(data=[semantic_text]),
|
||||
extract.DataChunk(data=[temporal]),
|
||||
],
|
||||
previews=True,
|
||||
modalities={"semantic", "temporal"},
|
||||
limit=limit,
|
||||
filters=SearchFilters(
|
||||
subject=subject,
|
||||
min_confidences=min_confidences,
|
||||
@ -256,7 +262,7 @@ async def search_observations(
|
||||
observation_types=observation_types,
|
||||
source_ids=filter_observation_source_ids(tags=tags),
|
||||
),
|
||||
timeout=2,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return [
|
||||
|
@ -17,7 +17,7 @@ from memory.api.search import scorer
|
||||
if settings.ENABLE_BM25_SEARCH:
|
||||
from memory.api.search.bm25 import search_bm25_chunks
|
||||
|
||||
from memory.api.search.types import SearchFilters, SearchResult
|
||||
from memory.api.search.types import SearchConfig, SearchFilters, SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,7 +57,7 @@ async def search_chunks(
|
||||
|
||||
|
||||
async def search_sources(
|
||||
chunks: list[Chunk], previews: Optional[bool] = False
|
||||
chunks: list[Chunk], previews: bool = False
|
||||
) -> list[SearchResult]:
|
||||
by_source = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
@ -73,11 +73,9 @@ async def search_sources(
|
||||
|
||||
async def search(
|
||||
data: list[extract.DataChunk],
|
||||
previews: Optional[bool] = False,
|
||||
modalities: set[str] = set(),
|
||||
limit: int = 10,
|
||||
filters: SearchFilters = {},
|
||||
timeout: int = 20,
|
||||
config: SearchConfig = SearchConfig(),
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search across knowledge base using text query and optional files.
|
||||
@ -95,13 +93,13 @@ async def search(
|
||||
chunks = await search_chunks(
|
||||
data,
|
||||
allowed_modalities,
|
||||
limit,
|
||||
config.limit,
|
||||
filters,
|
||||
timeout,
|
||||
config.timeout,
|
||||
)
|
||||
if settings.ENABLE_SEARCH_SCORING:
|
||||
if settings.ENABLE_SEARCH_SCORING and config.useScores:
|
||||
chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)
|
||||
|
||||
sources = await search_sources(chunks, previews)
|
||||
sources = await search_sources(chunks, config.previews)
|
||||
sources.sort(key=lambda x: x.search_score or 0, reverse=True)
|
||||
return sources[:limit]
|
||||
return sources[: config.limit]
|
||||
|
@ -2,10 +2,9 @@ from datetime import datetime
|
||||
import logging
|
||||
from typing import Optional, TypedDict, NotRequired, cast
|
||||
|
||||
from memory.common.db.models.source_item import SourceItem
|
||||
from pydantic import BaseModel
|
||||
|
||||
from memory.common.db.models import Chunk
|
||||
from memory.common.db.models import Chunk, SourceItem
|
||||
from memory.common import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -69,3 +68,10 @@ class SearchFilters(TypedDict):
|
||||
min_confidences: NotRequired[dict[str, float]]
|
||||
observation_types: NotRequired[list[str] | None]
|
||||
source_ids: NotRequired[list[int] | None]
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
limit: int = 10
|
||||
timeout: int = 20
|
||||
previews: bool = False
|
||||
useScores: bool = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user