mirror of https://github.com/mruwnik/memory.git (synced 2025-06-29 07:34:43 +02:00)

commit 387bd962e6 (parent 1276b83ffb): more improvements
@@ -40,8 +40,7 @@ const Search = () => {

    setIsLoading(true)
    try {
-      console.log(params)
-      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
+      const searchResults = await searchKnowledgeBase(params.query, params.modalities, params.filters, params.config)
      setResults(searchResults || [])
    } catch (error) {
      console.error('Search error:', error)
@@ -10,12 +10,17 @@ type Filter = {
   [key: string]: any
 }

+type SearchConfig = {
+  previews: boolean
+  useScores: boolean
+  limit: number
+}
+
 export interface SearchParams {
   query: string
-  previews: boolean
   modalities: string[]
   filters: Filter
-  limit: number
+  config: SearchConfig
 }

 interface SearchFormProps {
@@ -40,6 +45,7 @@ const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
 export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
   const [query, setQuery] = useState('')
   const [previews, setPreviews] = useState(false)
+  const [useScores, setUseScores] = useState(false)
   const [modalities, setModalities] = useState<Record<string, boolean>>({})
   const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
   const [tags, setTags] = useState<Record<string, boolean>>({})
@@ -68,13 +74,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {

     onSearch({
       query,
-      previews,
       modalities: getSelectedItems(modalities),
+      config: {
+        previews,
+        useScores,
+        limit
+      },
       filters: {
         tags: getSelectedItems(tags),
         ...cleanFilters(dynamicFilters)
       },
-      limit
     })
   }

@@ -105,6 +114,16 @@ export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
             Include content previews
           </label>
         </div>
+        <div className="search-option">
+          <label>
+            <input
+              type="checkbox"
+              checked={useScores}
+              onChange={(e) => setUseScores(e.target.checked)}
+            />
+            Score results with a LLM before returning
+          </label>
+        </div>

         <SelectableTags
           title="Modalities"
@@ -151,13 +151,12 @@ export const useMCP = () => {
     return (await mcpCall('get_metadata_schemas'))[0]
   }, [mcpCall])

-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+  const searchKnowledgeBase = useCallback(async (query: string, modalities: string[] = [], filters: Record<string, any> = {}, config: Record<string, any> = {}) => {
     return await mcpCall('search_knowledge_base', {
       query,
       filters,
+      config,
       modalities,
-      previews,
-      limit,
     })
   }, [mcpCall])

@@ -6,17 +6,16 @@ import base64
 import logging
 import pathlib
 from datetime import datetime, timezone
-from typing import Any
 from PIL import Image

 from pydantic import BaseModel
 from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY
-from mcp.server.fastmcp.resources.base import Resource

 from memory.api.MCP.tools import mcp
-from memory.api.search.search import SearchFilters, search
+from memory.api.search.search import search
+from memory.api.search.types import SearchFilters, SearchConfig
 from memory.common import extract, settings
 from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
@@ -80,22 +79,32 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int]
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    filters: dict[str, Any],
+    filters: SearchFilters,
+    config: SearchConfig = SearchConfig(),
     modalities: set[str] = set(),
-    previews: bool = False,
-    limit: int = 10,
 ) -> list[dict]:
     """
     Search user's stored content including emails, documents, articles, books.
     Use to find specific information the user has saved or received.
     Combine with search_observations for complete user context.
+    Use the `get_metadata_schemas` tool to get the metadata schema for each collection, from which you can infer the keys for the filters dictionary.
+
+    If you know what kind of data you're looking for, it's worth explicitly filtering by that modality, as this gives better results.
+
     Args:
         query: Natural language search query - be descriptive about what you're looking for
-        previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        filters: Filter by tags, source_ids, etc.
-        limit: Max results (1-100)
+        filters: a dictionary with the following keys:
+            - tags: a list of tags to filter by
+            - source_ids: a list of source ids to filter by
+            - min_size: the minimum size of the content to filter by
+            - max_size: the maximum size of the content to filter by
+            - min_created_at: the minimum created_at date to filter by
+            - max_created_at: the maximum created_at date to filter by
+        config: a dictionary with the following keys:
+            - limit: the maximum number of results to return
+            - previews: whether to include the actual content in the results (up to MAX_PREVIEW_LENGTH characters)
+            - useScores: whether to score the results with a LLM before returning - this results in better results but is slower

     Returns: List of search results with id, score, chunks, content, filename
     Higher scores (>0.7) indicate strong matches.
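For example, a client could send tool arguments shaped like this; a hedged sketch with made-up values, the key names simply follow the docstring above:

    # Illustrative arguments for a search_knowledge_base tool call
    # (example values only; keys mirror the docstring's filters/config sections).
    arguments = {
        "query": "notes about the postgres migration",
        "modalities": ["email", "blog"],
        "filters": {"tags": ["work"], "min_created_at": "2025-01-01"},
        "config": {"limit": 5, "previews": True, "useScores": True},
    }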
@@ -112,10 +121,9 @@ async def search_knowledge_base(
     upload_data = extract.extract_text(query, skip_summary=True)
     results = await search(
         upload_data,
-        previews=previews,
         modalities=modalities,
-        limit=limit,
         filters=search_filters,
+        config=config,
     )

     return [result.model_dump() for result in results]
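Each returned item is a plain dict (SearchResult.model_dump()), and per the docstring scores above 0.7 indicate strong matches; a hedged sketch of filtering on that, assuming the dumped dict exposes the score under the "score" key:

    # Sketch only: the exact key name for the score in the dumped dict is assumed.
    strong_matches = [r for r in results if (r.get("score") or 0) > 0.7]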
@@ -211,7 +219,7 @@ async def search_observations(
     tags: list[str] | None = None,
     observation_types: list[str] | None = None,
     min_confidences: dict[str, float] = {},
-    limit: int = 10,
+    config: SearchConfig = SearchConfig(),
 ) -> list[dict]:
     """
     Search recorded observations about the user.
@@ -224,7 +232,7 @@ async def search_observations(
         tags: Filter by tags (must have at least one matching tag)
         observation_types: Filter by: belief, preference, behavior, contradiction, general
         min_confidences: Minimum confidence thresholds, e.g. {"observation_accuracy": 0.8}
-        limit: Max results (1-100)
+        config: SearchConfig

     Returns: List with content, tags, created_at, metadata
     Results sorted by relevance to your query.
@@ -246,9 +254,7 @@ async def search_observations(
             extract.DataChunk(data=[semantic_text]),
             extract.DataChunk(data=[temporal]),
         ],
-        previews=True,
         modalities={"semantic", "temporal"},
-        limit=limit,
         filters=SearchFilters(
             subject=subject,
             min_confidences=min_confidences,
@@ -256,7 +262,7 @@ async def search_observations(
             observation_types=observation_types,
             source_ids=filter_observation_source_ids(tags=tags),
         ),
-        timeout=2,
+        config=config,
     )

     return [
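search_observations now takes the same SearchConfig in place of a bare limit; a minimal sketch of a call with illustrative values (only tags, observation_types, min_confidences and config are confirmed by these hunks, the query argument is assumed from context since the full signature is not shown):

    # Hedged sketch: query is an assumed parameter; other keyword arguments
    # appear in the hunks above.
    obs = await search_observations(
        query="what does the user think about static typing?",
        tags=["programming"],
        observation_types=["preference", "belief"],
        min_confidences={"observation_accuracy": 0.8},
        config=SearchConfig(limit=5),
    )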
@@ -17,7 +17,7 @@ from memory.api.search import scorer
 if settings.ENABLE_BM25_SEARCH:
     from memory.api.search.bm25 import search_bm25_chunks

-from memory.api.search.types import SearchFilters, SearchResult
+from memory.api.search.types import SearchConfig, SearchFilters, SearchResult

 logger = logging.getLogger(__name__)

@@ -57,7 +57,7 @@ async def search_chunks(


 async def search_sources(
-    chunks: list[Chunk], previews: Optional[bool] = False
+    chunks: list[Chunk], previews: bool = False
 ) -> list[SearchResult]:
     by_source = defaultdict(list)
     for chunk in chunks:
@@ -73,11 +73,9 @@ async def search_sources(

 async def search(
     data: list[extract.DataChunk],
-    previews: Optional[bool] = False,
     modalities: set[str] = set(),
-    limit: int = 10,
     filters: SearchFilters = {},
-    timeout: int = 20,
+    config: SearchConfig = SearchConfig(),
 ) -> list[SearchResult]:
     """
     Search across knowledge base using text query and optional files.
@@ -95,13 +93,13 @@ async def search(
     chunks = await search_chunks(
         data,
         allowed_modalities,
-        limit,
+        config.limit,
         filters,
-        timeout,
+        config.timeout,
     )
-    if settings.ENABLE_SEARCH_SCORING:
+    if settings.ENABLE_SEARCH_SCORING and config.useScores:
         chunks = await scorer.rank_chunks(data[0].data[0], chunks, min_score=0.3)

-    sources = await search_sources(chunks, previews)
+    sources = await search_sources(chunks, config.previews)
     sources.sort(key=lambda x: x.search_score or 0, reverse=True)
-    return sources[:limit]
+    return sources[: config.limit]
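Callers of the internal search() now hand over a single SearchConfig instead of separate previews/limit/timeout arguments; a hedged sketch of a call site, reusing the upload_data/search_filters names from the MCP hunk above and illustrative values elsewhere:

    # Illustrative only: bundle the former keyword arguments into one SearchConfig.
    cfg = SearchConfig(limit=20, timeout=10, previews=True, useScores=False)
    sources = await search(
        upload_data,
        modalities={"blog", "email"},
        filters=search_filters,
        config=cfg,
    )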
@@ -2,10 +2,9 @@ from datetime import datetime
 import logging
 from typing import Optional, TypedDict, NotRequired, cast

-from memory.common.db.models.source_item import SourceItem
 from pydantic import BaseModel

-from memory.common.db.models import Chunk
+from memory.common.db.models import Chunk, SourceItem
 from memory.common import settings

 logger = logging.getLogger(__name__)
@@ -69,3 +68,10 @@ class SearchFilters(TypedDict):
     min_confidences: NotRequired[dict[str, float]]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
+
+
+class SearchConfig(BaseModel):
+    limit: int = 10
+    timeout: int = 20
+    previews: bool = False
+    useScores: bool = False
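Since SearchConfig is a Pydantic model with defaults, callers only override the fields they need; a small sketch of how the defaults behave:

    # Defaults per the model above: limit=10, timeout=20, previews=False, useScores=False.
    cfg = SearchConfig(useScores=True)
    assert cfg.limit == 10 and cfg.timeout == 20 and cfg.previews is False
    cfg.model_dump()  # {'limit': 10, 'timeout': 20, 'previews': False, 'useScores': True}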