@@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
)
}
-const EmailResult = ({ content, tags, metadata }: SearchItem) => {
+export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
return (
{metadata?.title || metadata?.subject || 'Untitled'}
@@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => {
)
}
-const SearchResult = ({ result }: { result: SearchItem }) => {
+export const SearchResult = ({ result }: { result: SearchItem }) => {
if (result.mime_type.startsWith('image/')) {
return
}
@@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => {
return null
}
-const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
- if (isLoading) {
- return
- }
- return (
-
- {results.length > 0 && (
-
- Found {results.length} result{results.length !== 1 ? 's' : ''}
-
- )}
-
- {results.map((result, index) =>
)}
-
- {results.length === 0 && (
-
- No results found
-
- )}
-
- )
-}
-
-const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => {
- const [query, setQuery] = useState('')
- return (
-
- )
-}
-
-const Search = () => {
- const navigate = useNavigate()
- const [results, setResults] = useState([])
- const [isLoading, setIsLoading] = useState(false)
- const { searchKnowledgeBase } = useMCP()
-
- const handleSearch = async (query: string) => {
- if (!query.trim()) return
-
- setIsLoading(true)
- try {
- const searchResults = await searchKnowledgeBase(query)
- setResults(searchResults || [])
- } catch (error) {
- console.error('Search error:', error)
- setResults([])
- } finally {
- setIsLoading(false)
- }
- }
-
- return (
-
-
-
- 🔍 Search Knowledge Base
-
-
-
-
-
- )
-}
-
-export default Search
\ No newline at end of file
+export default SearchResult
\ No newline at end of file
diff --git a/frontend/src/hooks/useAuth.ts b/frontend/src/hooks/useAuth.ts
index 697b930..557689f 100644
--- a/frontend/src/hooks/useAuth.ts
+++ b/frontend/src/hooks/useAuth.ts
@@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000'
const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'
// Cookie utilities
-const getCookie = (name) => {
+const getCookie = (name: string) => {
const value = `; ${document.cookie}`
const parts = value.split(`; ${name}=`)
if (parts.length === 2) return parts.pop().split(';').shift()
return null
}
-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
const expires = new Date()
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
}
-const deleteCookie = (name) => {
+const deleteCookie = (name: string) => {
document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
}
@@ -68,6 +68,7 @@ export const useAuth = () => {
deleteCookie('access_token')
deleteCookie('refresh_token')
deleteCookie(SESSION_COOKIE_NAME)
+ localStorage.removeItem('oauth_client_id')
setIsAuthenticated(false)
}, [])
@@ -110,7 +111,7 @@ export const useAuth = () => {
}, [logout])
// Make authenticated API calls with automatic token refresh
- const apiCall = useCallback(async (endpoint, options = {}) => {
+ const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => {
let accessToken = getCookie('access_token')
if (!accessToken) {
@@ -122,7 +123,7 @@ export const useAuth = () => {
'Content-Type': 'application/json',
}
- const requestOptions = {
+  const requestOptions: RequestInit & { headers: Record<string, string> } = {
...options,
headers: { ...defaultHeaders, ...options.headers },
}
diff --git a/frontend/src/hooks/useMCP.ts b/frontend/src/hooks/useMCP.ts
index ca2c150..304cc96 100644
--- a/frontend/src/hooks/useMCP.ts
+++ b/frontend/src/hooks/useMCP.ts
@@ -1,5 +1,5 @@
import { useEffect, useCallback } from 'react'
-import { useAuth } from './useAuth'
+import { useAuth } from '@/hooks/useAuth'
const parseServerSentEvents = async (response: Response): Promise => {
const reader = response.body?.getReader()
@@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise => {
}
export const useMCP = () => {
- const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth()
+ const { apiCall, checkAuth } = useAuth()
- const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => {
- const response = await apiCall(`/mcp${path}`, {
+ const mcpCall = useCallback(async (method: string, params: any = {}) => {
+ const response = await apiCall(`/mcp/${method}`, {
method: 'POST',
headers: {
'Accept': 'application/json, text/event-stream',
@@ -118,22 +118,46 @@ export const useMCP = () => {
if (resp?.result?.isError) {
throw new Error(resp?.result?.content[0].text)
}
- return resp?.result?.content.map((item: any) => JSON.parse(item.text))
+ return resp?.result?.content.map((item: any) => {
+ try {
+ return JSON.parse(item.text)
+ } catch (e) {
+ return item.text
+ }
+ })
}, [apiCall])
const listNotes = useCallback(async (path: string = "/") => {
- return await mcpCall('/note_files', 'note_files', { path })
+ return await mcpCall('note_files', { path })
}, [mcpCall])
const fetchFile = useCallback(async (filename: string) => {
- return await mcpCall('/fetch_file', 'fetch_file', { filename })
+ return await mcpCall('fetch_file', { filename })
}, [mcpCall])
- const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => {
- return await mcpCall('/search_knowledge_base', 'search_knowledge_base', {
+ const getTags = useCallback(async () => {
+ return await mcpCall('get_all_tags')
+ }, [mcpCall])
+
+ const getSubjects = useCallback(async () => {
+ return await mcpCall('get_all_subjects')
+ }, [mcpCall])
+
+ const getObservationTypes = useCallback(async () => {
+ return await mcpCall('get_all_observation_types')
+ }, [mcpCall])
+
+ const getMetadataSchemas = useCallback(async () => {
+ return (await mcpCall('get_metadata_schemas'))[0]
+ }, [mcpCall])
+
+ const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+ return await mcpCall('search_knowledge_base', {
query,
+ filters,
+ modalities,
previews,
- limit
+ limit,
})
}, [mcpCall])
@@ -146,5 +170,9 @@ export const useMCP = () => {
fetchFile,
listNotes,
searchKnowledgeBase,
+ getTags,
+ getSubjects,
+ getObservationTypes,
+ getMetadataSchemas,
}
}
\ No newline at end of file
diff --git a/frontend/src/hooks/useOAuth.ts b/frontend/src/hooks/useOAuth.ts
index 8f215ff..b5b4491 100644
--- a/frontend/src/hooks/useOAuth.ts
+++ b/frontend/src/hooks/useOAuth.ts
@@ -14,7 +14,7 @@ const generateCodeVerifier = () => {
.replace(/=/g, '')
}
-const generateCodeChallenge = async (verifier) => {
+const generateCodeChallenge = async (verifier: string) => {
const data = new TextEncoder().encode(verifier)
const digest = await crypto.subtle.digest('SHA-256', data)
return btoa(String.fromCharCode(...new Uint8Array(digest)))
@@ -33,7 +33,7 @@ const generateState = () => {
}
// Storage utilities
-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
const expires = new Date()
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
diff --git a/frontend/src/main.jsx b/frontend/src/main.jsx
index 3d9da8a..f9bdd5f 100644
--- a/frontend/src/main.jsx
+++ b/frontend/src/main.jsx
@@ -1,6 +1,6 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
-import App from './App.jsx'
+import App from '@/App.jsx'
createRoot(document.getElementById('root')).render(
diff --git a/frontend/src/types/mcp.tsx b/frontend/src/types/mcp.tsx
new file mode 100644
index 0000000..8344a5d
--- /dev/null
+++ b/frontend/src/types/mcp.tsx
@@ -0,0 +1,9 @@
+export type CollectionMetadata = {
+ schema: Record<string, SchemaArg>
+ size: number
+}
+
+export type SchemaArg = {
+ type: string
+ description: string
+}
\ No newline at end of file
diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
new file mode 100644
index 0000000..8503bc8
--- /dev/null
+++ b/frontend/tsconfig.json
@@ -0,0 +1,40 @@
+{
+ "compilerOptions": {
+ "target": "ES2020",
+ "useDefineForClassFields": true,
+ "lib": [
+ "ES2020",
+ "DOM",
+ "DOM.Iterable"
+ ],
+ "module": "ESNext",
+ "skipLibCheck": true,
+ /* Bundler mode */
+ "moduleResolution": "bundler",
+ "allowImportingTsExtensions": true,
+ "resolveJsonModule": true,
+ "isolatedModules": true,
+ "noEmit": true,
+ "jsx": "react-jsx",
+ /* Linting */
+ "strict": true,
+ "noUnusedLocals": true,
+ "noUnusedParameters": true,
+ "noFallthroughCasesInSwitch": true,
+ /* Absolute imports */
+ "baseUrl": ".",
+ "paths": {
+ "@/*": [
+ "src/*"
+ ]
+ }
+ },
+ "include": [
+ "src"
+ ],
+ "references": [
+ {
+ "path": "./tsconfig.node.json"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/frontend/tsconfig.node.json b/frontend/tsconfig.node.json
new file mode 100644
index 0000000..951e593
--- /dev/null
+++ b/frontend/tsconfig.node.json
@@ -0,0 +1,12 @@
+{
+ "compilerOptions": {
+ "composite": true,
+ "skipLibCheck": true,
+ "module": "ESNext",
+ "moduleResolution": "bundler",
+ "allowSyntheticDefaultImports": true
+ },
+ "include": [
+ "vite.config.js"
+ ]
+}
\ No newline at end of file
diff --git a/frontend/vite.config.js b/frontend/vite.config.js
index a7e9c40..51ed537 100644
--- a/frontend/vite.config.js
+++ b/frontend/vite.config.js
@@ -1,8 +1,14 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
+import path from 'path'
// https://vite.dev/config/
export default defineConfig({
plugins: [react()],
base: '/ui/',
+ resolve: {
+ alias: {
+ '@': path.resolve(__dirname, './src'),
+ },
+ },
})
diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py
index 5ee8067..8711b5d 100644
--- a/src/memory/api/MCP/__init__.py
+++ b/src/memory/api/MCP/__init__.py
@@ -1,2 +1,3 @@
import memory.api.MCP.manifest
import memory.api.MCP.memory
+import memory.api.MCP.metadata
diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py
index 4710afb..440d65f 100644
--- a/src/memory/api/MCP/base.py
+++ b/src/memory/api/MCP/base.py
@@ -127,7 +127,6 @@ async def handle_login(request: Request):
return login_form(request, oauth_params, "Invalid email or password")
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
- print("redirect_url", redirect_url)
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
redirect_url = redirect_url.replace("http://", "cursor://")
return RedirectResponse(url=redirect_url, status_code=302)
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py
index 89aaf30..6d63519 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/memory.py
@@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system.
"""
import logging
+import mimetypes
import pathlib
from datetime import datetime, timezone
-import mimetypes
+from typing import Any
from pydantic import BaseModel
-from sqlalchemy import Text, func
+from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY
+from memory.api.MCP.tools import mcp
from memory.api.search.search import SearchFilters, search
from memory.common import extract, settings
+from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
+from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
-from memory.common.db.models import AgentObservation, SourceItem
+from memory.common.db.models import SourceItem, AgentObservation
from memory.common.formatters import observation
-from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
-from memory.api.MCP.tools import mcp
logger = logging.getLogger(__name__)
@@ -47,11 +49,13 @@ def filter_observation_source_ids(
return source_ids
-def filter_source_ids(
- modalities: set[str],
- tags: list[str] | None = None,
-):
- if not tags:
+def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None:
+ if source_ids := filters.get("source_ids"):
+ return source_ids
+
+ tags = filters.get("tags")
+ size = filters.get("size")
+ if not (tags or size):
return None
with make_session() as session:
@@ -62,6 +66,8 @@ def filter_source_ids(
items_query = items_query.filter(
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
)
+ if size:
+ items_query = items_query.filter(SourceItem.size == size)
if modalities:
items_query = items_query.filter(SourceItem.modality.in_(modalities))
source_ids = [item.id for item in items_query.all()]
@@ -69,51 +75,12 @@ def filter_source_ids(
return source_ids
-@mcp.tool()
-async def get_all_tags() -> list[str]:
- """
- Get all unique tags used across the entire knowledge base.
- Returns sorted list of tags from both observations and content.
- """
- with make_session() as session:
- tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
- return sorted({row[0] for row in tags_query if row[0] is not None})
-
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
- """
- Get all unique subjects from observations about the user.
- Returns sorted list of subject identifiers used in observations.
- """
- with make_session() as session:
- return sorted(
- r.subject for r in session.query(AgentObservation.subject).distinct()
- )
-
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
- """
- Get all observation types that have been used.
- Standard types are belief, preference, behavior, contradiction, general, but there can be more.
- """
- with make_session() as session:
- return sorted(
- {
- r.observation_type
- for r in session.query(AgentObservation.observation_type).distinct()
- if r.observation_type is not None
- }
- )
-
-
@mcp.tool()
async def search_knowledge_base(
query: str,
- previews: bool = False,
+ filters: dict[str, Any],
modalities: set[str] = set(),
- tags: list[str] = [],
+ previews: bool = False,
limit: int = 10,
) -> list[dict]:
"""
@@ -125,7 +92,7 @@ async def search_knowledge_base(
query: Natural language search query - be descriptive about what you're looking for
previews: Include actual content in results - when false only a snippet is returned
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
- tags: Filter by tags - content must have at least one matching tag
+ filters: Filter by tags, source_ids, etc.
limit: Max results (1-100)
Returns: List of search results with id, score, chunks, content, filename
@@ -137,6 +104,9 @@ async def search_knowledge_base(
modalities = set(ALL_COLLECTIONS.keys())
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
+ search_filters = SearchFilters(**filters)
+ search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
+
upload_data = extract.extract_text(query)
results = await search(
upload_data,
@@ -145,10 +115,7 @@ async def search_knowledge_base(
limit=limit,
min_text_score=0.4,
min_multimodal_score=0.25,
- filters=SearchFilters(
- tags=tags,
- source_ids=filter_source_ids(tags=tags, modalities=modalities),
- ),
+ filters=search_filters,
)
return [result.model_dump() for result in results]
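For orientation, a minimal sketch of the kind of `filters` dict a caller can now pass to `search_knowledge_base`; the keys shown are the ones handled by `SearchFilters`, `filter_source_ids`, and `merge_filters` in this diff, and the concrete values are made up:

```python
# Hypothetical filters payload for search_knowledge_base (values are examples).
filters = {
    "tags": ["ai", "search"],                          # used by filter_source_ids and as a Qdrant tag match
    "min_size": 1_000,                                 # range filter on the size payload field
    "max_published": "2024-01-01",                     # range filter on the published payload field
    "min_confidences": {"observation_accuracy": 0.8},  # per-type confidence thresholds
}

# Inside the tool this becomes typed filters plus pre-resolved source IDs:
#   search_filters = SearchFilters(**filters)
#   search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
```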
diff --git a/src/memory/api/MCP/metadata.py b/src/memory/api/MCP/metadata.py
new file mode 100644
index 0000000..ce54bec
--- /dev/null
+++ b/src/memory/api/MCP/metadata.py
@@ -0,0 +1,119 @@
+import logging
+from collections import defaultdict
+from typing import Annotated, TypedDict, get_args, get_type_hints
+
+from memory.common import qdrant
+from sqlalchemy import func
+
+from memory.api.MCP.tools import mcp
+from memory.common.db.connection import make_session
+from memory.common.db.models import SourceItem
+from memory.common.db.models.source_items import AgentObservation
+
+logger = logging.getLogger(__name__)
+
+
+class SchemaArg(TypedDict):
+ type: str | None
+ description: str | None
+
+
+class CollectionMetadata(TypedDict):
+ schema: dict[str, SchemaArg]
+ size: int
+
+
+def from_annotation(annotation: Annotated) -> SchemaArg | None:
+ try:
+ type_, description = get_args(annotation)
+ type_str = str(type_)
+ if type_str.startswith("typing."):
+ type_str = type_str[7:]
+ elif len((parts := type_str.split("'"))) > 1:
+ type_str = parts[1]
+ return SchemaArg(type=type_str, description=description)
+ except (ValueError, IndexError):
+ logger.error(f"Error from annotation: {annotation}")
+ return None
+
+
+def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
+ if not hasattr(klass, "as_payload"):
+ return {}
+
+ if not (payload_type := get_type_hints(klass.as_payload).get("return")):
+ return {}
+
+ return {
+ name: schema
+ for name, arg in payload_type.__annotations__.items()
+ if (schema := from_annotation(arg))
+ }
+
+
+@mcp.tool()
+async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
+ """Get the metadata schema for each collection used in the knowledge base.
+
+ These schemas can be used to filter the knowledge base.
+
+ Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
+
+ Example:
+ ```
+ {
+ "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
+ "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
+ }
+ """
+ client = qdrant.get_qdrant_client()
+ sizes = qdrant.get_collection_sizes(client)
+ schemas = defaultdict(dict)
+ for klass in SourceItem.__subclasses__():
+ for collection in klass.get_collections():
+ schemas[collection].update(get_schema(klass))
+
+ return {
+ collection: CollectionMetadata(schema=schema, size=size)
+ for collection, schema in schemas.items()
+ if (size := sizes.get(collection))
+ }
+
+
+@mcp.tool()
+async def get_all_tags() -> list[str]:
+ """Get all unique tags used across the entire knowledge base.
+
+ Returns sorted list of tags from both observations and content.
+ """
+ with make_session() as session:
+ tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
+ return sorted({row[0] for row in tags_query if row[0] is not None})
+
+
+@mcp.tool()
+async def get_all_subjects() -> list[str]:
+ """Get all unique subjects from observations about the user.
+
+ Returns sorted list of subject identifiers used in observations.
+ """
+ with make_session() as session:
+ return sorted(
+ r.subject for r in session.query(AgentObservation.subject).distinct()
+ )
+
+
+@mcp.tool()
+async def get_all_observation_types() -> list[str]:
+ """Get all observation types that have been used.
+
+ Standard types are belief, preference, behavior, contradiction, general, but there can be more.
+ """
+ with make_session() as session:
+ return sorted(
+ {
+ r.observation_type
+ for r in session.query(AgentObservation.observation_type).distinct()
+ if r.observation_type is not None
+ }
+ )
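To make the Annotated-to-schema mapping concrete, here is a small self-contained sketch mirroring the splitting done by `from_annotation`; `DemoPayload` and `annotation_to_schema` are illustrative names, not part of the codebase:

```python
from typing import Annotated, TypedDict, get_args


class DemoPayload(TypedDict):
    # Invented payload type, just to show the Annotated[type, description] convention.
    source_id: Annotated[int, "Unique identifier of the source item"]
    subject: Annotated[str | None, "What the item is about"]


def annotation_to_schema(annotation) -> dict:
    # Same splitting as from_annotation(): Annotated[T, description] -> type string + description.
    type_, description = get_args(annotation)
    type_str = str(type_)
    if type_str.startswith("typing."):
        type_str = type_str[7:]
    elif len(parts := type_str.split("'")) > 1:
        type_str = parts[1]
    return {"type": type_str, "description": description}


schema = {name: annotation_to_schema(arg) for name, arg in DemoPayload.__annotations__.items()}
print(schema)
# {'source_id': {'type': 'int', 'description': 'Unique identifier of the source item'},
#  'subject': {'type': 'str | None', 'description': 'What the item is about'}}
```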
diff --git a/src/memory/api/app.py b/src/memory/api/app.py
index 57ffbd6..2f1974c 100644
--- a/src/memory/api/app.py
+++ b/src/memory/api/app.py
@@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
# SQLAdmin setup with OAuth protection
engine = get_engine()
admin = Admin(app, engine)
+admin.app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # [settings.SERVER_URL],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin)
diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py
index 6af1d06..c1bfcfe 100644
--- a/src/memory/api/search/embeddings.py
+++ b/src/memory/api/search/embeddings.py
@@ -1,7 +1,7 @@
import base64
import io
import logging
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, cast
import qdrant_client
from PIL import Image
@@ -90,13 +90,71 @@ def query_chunks(
}
+def merge_range_filter(
+ filters: list[dict[str, Any]], key: str, val: Any
+) -> list[dict[str, Any]]:
+ direction, field = key.split("_", maxsplit=1)
+ item = next((f for f in filters if f["key"] == field), None)
+ if not item:
+ item = {"key": field, "range": {}}
+ filters.append(item)
+
+ if direction == "min":
+ item["range"]["gte"] = val
+ elif direction == "max":
+ item["range"]["lte"] = val
+ return filters
+
+
+def merge_filters(
+ filters: list[dict[str, Any]], key: str, val: Any
+) -> list[dict[str, Any]]:
+ if not val and val != 0:
+ return filters
+
+ list_filters = ["tags", "recipients", "observation_types", "authors"]
+ range_filters = [
+ "min_sent_at",
+ "max_sent_at",
+ "min_published",
+ "max_published",
+ "min_size",
+ "max_size",
+ "min_created_at",
+ "max_created_at",
+ ]
+ if key in list_filters:
+ filters.append({"key": key, "match": {"any": val}})
+
+ elif key in range_filters:
+ return merge_range_filter(filters, key, val)
+
+ elif key == "min_confidences":
+ confidence_filters = [
+ {
+ "key": f"confidence.{confidence_type}",
+ "range": {"gte": min_confidence_score},
+ }
+ for confidence_type, min_confidence_score in cast(dict, val).items()
+ ]
+ filters.extend(confidence_filters)
+
+ elif key == "source_ids":
+ filters.append({"key": "id", "match": {"any": val}})
+
+ else:
+ filters.append({"key": key, "match": val})
+
+ return filters
+
+
async def search_embeddings(
data: list[extract.DataChunk],
previews: Optional[bool] = False,
modalities: set[str] = set(),
limit: int = 10,
min_score: float = 0.3,
- filters: SearchFilters = SearchFilters(),
+ filters: SearchFilters = {},
multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
"""
@@ -111,27 +169,11 @@ async def search_embeddings(
- filters: Filters to apply to the search results
- multimodal: Whether to search in multimodal collections
"""
- query_filters = {"must": []}
-
- # Handle structured confidence filtering
- if min_confidences := filters.get("min_confidences"):
- confidence_filters = [
- {
- "key": f"confidence.{confidence_type}",
- "range": {"gte": min_confidence_score},
- }
- for confidence_type, min_confidence_score in min_confidences.items()
- ]
- query_filters["must"].extend(confidence_filters)
-
- if tags := filters.get("tags"):
- query_filters["must"].append({"key": "tags", "match": {"any": tags}})
-
- if observation_types := filters.get("observation_types"):
- query_filters["must"].append(
- {"key": "observation_type", "match": {"any": observation_types}}
- )
+ search_filters = []
+ for key, val in filters.items():
+ search_filters = merge_filters(search_filters, key, val)
client = qdrant.get_qdrant_client()
results = query_chunks(
client,
@@ -140,7 +182,7 @@ async def search_embeddings(
embedding.embed_text if not multimodal else embedding.embed_mixed,
min_score=min_score,
limit=limit,
- filters=query_filters if query_filters["must"] else None,
+ filters={"must": search_filters} if search_filters else None,
)
search_results = {k: results.get(k, []) for k in modalities}
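For a sense of what the new filter builder produces, a short usage sketch (filter values invented; the import path matches the test module added below):

```python
from memory.api.search.embeddings import merge_filters

# Build the Qdrant filter conditions for an example SearchFilters dict.
conditions: list[dict] = []
for key, val in {
    "tags": ["ai"],
    "min_published": "2024-01-01",
    "max_size": 1_000_000,
    "min_confidences": {"observation_accuracy": 0.8},
}.items():
    conditions = merge_filters(conditions, key, val)

# search_embeddings then passes {"must": conditions} to query_chunks, roughly:
# [{"key": "tags", "match": {"any": ["ai"]}},
#  {"key": "published", "range": {"gte": "2024-01-01"}},
#  {"key": "size", "range": {"lte": 1000000}},
#  {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}}]
```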
diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py
index 2e07d5d..610d59d 100644
--- a/src/memory/api/search/search.py
+++ b/src/memory/api/search/search.py
@@ -17,6 +17,7 @@ from memory.common.collections import (
MULTIMODAL_COLLECTIONS,
TEXT_COLLECTIONS,
)
+from memory.common import settings
logger = logging.getLogger(__name__)
@@ -44,51 +45,57 @@ async def search(
- List of search results sorted by score
"""
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
- text_embeddings_results = with_timeout(
- search_embeddings(
- data,
- previews,
- allowed_modalities & TEXT_COLLECTIONS,
- limit,
- min_text_score,
- filters,
- multimodal=False,
- ),
- timeout,
- )
- multimodal_embeddings_results = with_timeout(
- search_embeddings(
- data,
- previews,
- allowed_modalities & MULTIMODAL_COLLECTIONS,
- limit,
- min_multimodal_score,
- filters,
- multimodal=True,
- ),
- timeout,
- )
- bm25_results = with_timeout(
- search_bm25(
- " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
- modalities,
- limit=limit,
- filters=filters,
- ),
- timeout,
- )
+ searches = []
+ if settings.ENABLE_EMBEDDING_SEARCH:
+ searches = [
+ with_timeout(
+ search_embeddings(
+ data,
+ previews,
+ allowed_modalities & TEXT_COLLECTIONS,
+ limit,
+ min_text_score,
+ filters,
+ multimodal=False,
+ ),
+ timeout,
+ ),
+ with_timeout(
+ search_embeddings(
+ data,
+ previews,
+ allowed_modalities & MULTIMODAL_COLLECTIONS,
+ limit,
+ min_multimodal_score,
+ filters,
+ multimodal=True,
+ ),
+ timeout,
+ ),
+ ]
+ if settings.ENABLE_BM25_SEARCH:
+ searches.append(
+ with_timeout(
+ search_bm25(
+ " ".join(
+ [c for chunk in data for c in chunk.data if isinstance(c, str)]
+ ),
+ modalities,
+ limit=limit,
+ filters=filters,
+ ),
+ timeout,
+ )
+ )
- results = await asyncio.gather(
- text_embeddings_results,
- multimodal_embeddings_results,
- bm25_results,
- return_exceptions=False,
- )
- text_results, multi_results, bm25_results = results
- all_results = text_results + multi_results
- if len(all_results) < limit:
- all_results += bm25_results
+ search_results = await asyncio.gather(*searches, return_exceptions=False)
+ all_results = []
+ for results in search_results:
+ if len(all_results) >= limit:
+ break
+ all_results.extend(results)
results = group_chunks(all_results, previews or False)
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
diff --git a/src/memory/api/search/utils.py b/src/memory/api/search/utils.py
index 4a9da22..1833235 100644
--- a/src/memory/api/search/utils.py
+++ b/src/memory/api/search/utils.py
@@ -65,9 +65,9 @@ class SearchResult(BaseModel):
class SearchFilters(TypedDict):
- subject: NotRequired[str | None]
+ min_size: NotRequired[int]
+ max_size: NotRequired[int]
min_confidences: NotRequired[dict[str, float]]
- tags: NotRequired[list[str] | None]
observation_types: NotRequired[list[str] | None]
source_ids: NotRequired[list[int] | None]
@@ -115,7 +115,6 @@ def group_chunks(
if isinstance(contents, dict):
tags = contents.pop("tags", [])
content = contents.pop("content", None)
- print(content)
else:
content = contents
contents = {}
diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py
index 8888a8f..62ced5e 100644
--- a/src/memory/common/db/models/__init__.py
+++ b/src/memory/common/db/models/__init__.py
@@ -4,6 +4,7 @@ from memory.common.db.models.source_item import (
SourceItem,
ConfidenceScore,
clean_filename,
+ SourceItemPayload,
)
from memory.common.db.models.source_items import (
MailMessage,
@@ -19,6 +20,14 @@ from memory.common.db.models.source_items import (
Photo,
MiscDoc,
Note,
+ MailMessagePayload,
+ EmailAttachmentPayload,
+ AgentObservationPayload,
+ BlogPostPayload,
+ ComicPayload,
+ BookSectionPayload,
+ NotePayload,
+ ForumPostPayload,
)
from memory.common.db.models.observations import (
ObservationContradiction,
@@ -40,6 +49,18 @@ from memory.common.db.models.users import (
OAuthRefreshToken,
)
+Payload = (
+ SourceItemPayload
+ | AgentObservationPayload
+ | NotePayload
+ | BlogPostPayload
+ | ComicPayload
+ | BookSectionPayload
+ | ForumPostPayload
+ | EmailAttachmentPayload
+ | MailMessagePayload
+)
+
__all__ = [
"Base",
"Chunk",
@@ -75,4 +96,6 @@ __all__ = [
"OAuthClientInformation",
"OAuthState",
"OAuthRefreshToken",
+ # Payloads
+ "Payload",
]
diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py
index 8ec4a43..d6e68a5 100644
--- a/src/memory/common/db/models/source_item.py
+++ b/src/memory/common/db/models/source_item.py
@@ -4,7 +4,7 @@ Database models for the knowledge base system.
import pathlib
import re
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, TypedDict, cast
import uuid
from PIL import Image
@@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer
from memory.common.db.models.base import Base
+class MetadataSchema(TypedDict):
+ type: str
+ description: str
+
+
+class SourceItemPayload(TypedDict):
+ source_id: Annotated[int, "Unique identifier of the source item"]
+ tags: Annotated[list[str], "List of tags for categorization"]
+ size: Annotated[int | None, "Size of the content in bytes"]
+
+
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
"""
@@ -344,12 +355,17 @@ class SourceItem(Base):
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
- def as_payload(self) -> dict:
- return {
- "source_id": self.id,
- "tags": self.tags,
- "size": self.size,
- }
+ def as_payload(self) -> SourceItemPayload:
+ return SourceItemPayload(
+ source_id=cast(int, self.id),
+ tags=cast(list[str], self.tags),
+ size=cast(int | None, self.size),
+ )
+
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ """Return the list of Qdrant collections this SourceItem type can be stored in."""
+ return [cls.__tablename__]
@property
def display_contents(self) -> str | dict | None:
diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py
index b391b6f..1c9f313 100644
--- a/src/memory/common/db/models/source_items.py
+++ b/src/memory/common/db/models/source_items.py
@@ -5,7 +5,7 @@ Database models for the knowledge base system.
import pathlib
import textwrap
from datetime import datetime
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, cast
from PIL import Image
from sqlalchemy import (
@@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation
from memory.common.db.models.source_item import (
SourceItem,
Chunk,
+ SourceItemPayload,
clean_filename,
chunk_mixed,
)
+class MailMessagePayload(SourceItemPayload):
+ message_id: Annotated[str, "Unique email message identifier"]
+ subject: Annotated[str, "Email subject line"]
+ sender: Annotated[str, "Email sender address"]
+ recipients: Annotated[list[str], "List of recipient email addresses"]
+ folder: Annotated[str, "Email folder name"]
+ date: Annotated[str | None, "Email sent date in ISO format"]
+
+
class MailMessage(SourceItem):
__tablename__ = "mail_message"
@@ -80,17 +90,21 @@ class MailMessage(SourceItem):
path.parent.mkdir(parents=True, exist_ok=True)
return path
- def as_payload(self) -> dict:
- return {
- **super().as_payload(),
- "message_id": self.message_id,
- "subject": self.subject,
- "sender": self.sender,
- "recipients": self.recipients,
- "folder": self.folder,
- "tags": self.tags + [self.sender] + self.recipients,
- "date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore
+ def as_payload(self) -> MailMessagePayload:
+ base_payload = super().as_payload() | {
+ "tags": cast(list[str], self.tags)
+ + [cast(str, self.sender)]
+ + cast(list[str], self.recipients)
}
+ return MailMessagePayload(
+ **cast(dict, base_payload),
+ message_id=cast(str, self.message_id),
+ subject=cast(str, self.subject),
+ sender=cast(str, self.sender),
+ recipients=cast(list[str], self.recipients),
+ folder=cast(str, self.folder),
+ date=(self.sent_at and self.sent_at.isoformat() or None), # type: ignore
+ )
@property
def parsed_content(self) -> dict[str, Any]:
@@ -152,7 +166,7 @@ class MailMessage(SourceItem):
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
content = self.parsed_content
- chunks = extract.extract_text(cast(str, self.body))
+ chunks = extract.extract_text(cast(str, self.body), modality="email")
def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
if isinstance(item, str):
@@ -163,6 +177,10 @@ class MailMessage(SourceItem):
chunk.data = [add_header(item) for item in chunk.data]
return chunks
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ return ["mail"]
+
# Add indexes
__table_args__ = (
Index("mail_sent_idx", "sent_at"),
@@ -171,6 +189,13 @@ class MailMessage(SourceItem):
)
+class EmailAttachmentPayload(SourceItemPayload):
+ filename: Annotated[str, "Name of the document file"]
+ content_type: Annotated[str, "MIME type of the document"]
+ mail_message_id: Annotated[int, "Associated email message ID"]
+ sent_at: Annotated[str | None, "Document creation timestamp"]
+
+
class EmailAttachment(SourceItem):
__tablename__ = "email_attachment"
@@ -190,17 +215,20 @@ class EmailAttachment(SourceItem):
"polymorphic_identity": "email_attachment",
}
- def as_payload(self) -> dict:
- return {
+ def as_payload(self) -> EmailAttachmentPayload:
+ return EmailAttachmentPayload(
**super().as_payload(),
- "filename": self.filename,
- "content_type": self.mime_type,
- "size": self.size,
- "created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
- "mail_message_id": self.mail_message_id,
- }
+ filename=cast(str, self.filename),
+ content_type=cast(str, self.mime_type),
+ mail_message_id=cast(int, self.mail_message_id),
+ sent_at=(
+ self.mail_message.sent_at
+ and self.mail_message.sent_at.isoformat()
+ or None
+ ), # type: ignore
+ )
- def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+ def _chunk_contents(self) -> Sequence[extract.DataChunk]:
if cast(str | None, self.filename):
contents = (
settings.FILE_STORAGE_DIR / cast(str, self.filename)
@@ -208,8 +236,7 @@ class EmailAttachment(SourceItem):
else:
contents = cast(str, self.content)
- chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
- return [self._make_chunk(c, metadata) for c in chunks]
+ return extract.extract_data_chunks(cast(str, self.mime_type), contents)
@property
def display_contents(self) -> dict:
@@ -221,6 +248,11 @@ class EmailAttachment(SourceItem):
# Add indexes
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ """EmailAttachment can go to different collections based on mime_type"""
+ return ["doc", "text", "blog", "photo", "book"]
+
class ChatMessage(SourceItem):
__tablename__ = "chat_message"
@@ -285,6 +317,16 @@ class Photo(SourceItem):
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
+class ComicPayload(SourceItemPayload):
+ title: Annotated[str, "Title of the comic"]
+ author: Annotated[str | None, "Author of the comic"]
+ published: Annotated[str | None, "Publication date in ISO format"]
+ volume: Annotated[str | None, "Volume number"]
+ issue: Annotated[str | None, "Issue number"]
+ page: Annotated[int | None, "Page number"]
+ url: Annotated[str | None, "URL of the comic"]
+
+
class Comic(SourceItem):
__tablename__ = "comic"
@@ -305,18 +347,17 @@ class Comic(SourceItem):
__table_args__ = (Index("comic_author_idx", "author"),)
- def as_payload(self) -> dict:
- payload = {
+ def as_payload(self) -> ComicPayload:
+ return ComicPayload(
**super().as_payload(),
- "title": self.title,
- "author": self.author,
- "published": self.published,
- "volume": self.volume,
- "issue": self.issue,
- "page": self.page,
- "url": self.url,
- }
- return {k: v for k, v in payload.items() if v is not None}
+ title=cast(str, self.title),
+ author=cast(str | None, self.author),
+ published=(self.published and self.published.isoformat() or None), # type: ignore
+ volume=cast(str | None, self.volume),
+ issue=cast(str | None, self.issue),
+ page=cast(int | None, self.page),
+ url=cast(str | None, self.url),
+ )
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
@@ -324,6 +365,17 @@ class Comic(SourceItem):
return [extract.DataChunk(data=[image, description])]
+class BookSectionPayload(SourceItemPayload):
+ title: Annotated[str, "Title of the book"]
+ author: Annotated[str | None, "Author of the book"]
+ book_id: Annotated[int, "Unique identifier of the book"]
+ section_title: Annotated[str, "Title of the section"]
+ section_number: Annotated[int, "Number of the section"]
+ section_level: Annotated[int, "Level of the section"]
+ start_page: Annotated[int, "Starting page number"]
+ end_page: Annotated[int, "Ending page number"]
+
+
class BookSection(SourceItem):
"""Individual sections/chapters of books"""
@@ -361,19 +413,22 @@ class BookSection(SourceItem):
Index("book_section_level_idx", "section_level", "section_number"),
)
- def as_payload(self) -> dict:
- vals = {
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ return ["book"]
+
+ def as_payload(self) -> BookSectionPayload:
+ return BookSectionPayload(
**super().as_payload(),
- "title": self.book.title,
- "author": self.book.author,
- "book_id": self.book_id,
- "section_title": self.section_title,
- "section_number": self.section_number,
- "section_level": self.section_level,
- "start_page": self.start_page,
- "end_page": self.end_page,
- }
- return {k: v for k, v in vals.items() if v}
+ title=cast(str, self.book.title),
+ author=cast(str | None, self.book.author),
+ book_id=cast(int, self.book_id),
+ section_title=cast(str, self.section_title),
+ section_number=cast(int, self.section_number),
+ section_level=cast(int, self.section_level),
+ start_page=cast(int, self.start_page),
+ end_page=cast(int, self.end_page),
+ )
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
content = cast(str, self.content.strip())
@@ -397,6 +452,16 @@ class BookSection(SourceItem):
]
+class BlogPostPayload(SourceItemPayload):
+ url: Annotated[str, "URL of the blog post"]
+ title: Annotated[str, "Title of the blog post"]
+ author: Annotated[str | None, "Author of the blog post"]
+ published: Annotated[str | None, "Publication date in ISO format"]
+ description: Annotated[str | None, "Description of the blog post"]
+ domain: Annotated[str | None, "Domain of the blog post"]
+ word_count: Annotated[int | None, "Word count of the blog post"]
+
+
class BlogPost(SourceItem):
__tablename__ = "blog_post"
@@ -428,27 +493,39 @@ class BlogPost(SourceItem):
Index("blog_post_word_count_idx", "word_count"),
)
- def as_payload(self) -> dict:
+ def as_payload(self) -> BlogPostPayload:
published_date = cast(datetime | None, self.published)
metadata = cast(dict | None, self.webpage_metadata) or {}
- payload = {
+ return BlogPostPayload(
**super().as_payload(),
- "url": self.url,
- "title": self.title,
- "author": self.author,
- "published": published_date and published_date.isoformat(),
- "description": self.description,
- "domain": self.domain,
- "word_count": self.word_count,
+ url=cast(str, self.url),
+ title=cast(str, self.title),
+ author=cast(str | None, self.author),
+ published=(published_date and published_date.isoformat() or None), # type: ignore
+ description=cast(str | None, self.description),
+ domain=cast(str | None, self.domain),
+ word_count=cast(int | None, self.word_count),
**metadata,
- }
- return {k: v for k, v in payload.items() if v}
+ )
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
+class ForumPostPayload(SourceItemPayload):
+ url: Annotated[str, "URL of the forum post"]
+ title: Annotated[str, "Title of the forum post"]
+ description: Annotated[str | None, "Description of the forum post"]
+ authors: Annotated[list[str] | None, "Authors of the forum post"]
+ published: Annotated[str | None, "Publication date in ISO format"]
+ slug: Annotated[str | None, "Slug of the forum post"]
+ karma: Annotated[int | None, "Karma score of the forum post"]
+ votes: Annotated[int | None, "Number of votes on the forum post"]
+ score: Annotated[int | None, "Score of the forum post"]
+ comments: Annotated[int | None, "Number of comments on the forum post"]
+
+
class ForumPost(SourceItem):
__tablename__ = "forum_post"
@@ -479,20 +556,20 @@ class ForumPost(SourceItem):
Index("forum_post_title_idx", "title"),
)
- def as_payload(self) -> dict:
- return {
+ def as_payload(self) -> ForumPostPayload:
+ return ForumPostPayload(
**super().as_payload(),
- "url": self.url,
- "title": self.title,
- "description": self.description,
- "authors": self.authors,
- "published_at": self.published_at,
- "slug": self.slug,
- "karma": self.karma,
- "votes": self.votes,
- "score": self.score,
- "comments": self.comments,
- }
+ url=cast(str, self.url),
+ title=cast(str, self.title),
+ description=cast(str | None, self.description),
+ authors=cast(list[str] | None, self.authors),
+ published=(self.published_at and self.published_at.isoformat() or None), # type: ignore
+ slug=cast(str | None, self.slug),
+ karma=cast(int | None, self.karma),
+ votes=cast(int | None, self.votes),
+ score=cast(int | None, self.score),
+ comments=cast(int | None, self.comments),
+ )
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
@@ -545,6 +622,12 @@ class GithubItem(SourceItem):
)
+class NotePayload(SourceItemPayload):
+ note_type: Annotated[str | None, "Category of the note"]
+ subject: Annotated[str | None, "What the note is about"]
+ confidence: Annotated[dict[str, float], "Confidence scores for the note"]
+
+
class Note(SourceItem):
"""A quick note of something of interest."""
@@ -565,13 +648,13 @@ class Note(SourceItem):
Index("note_subject_idx", "subject"),
)
- def as_payload(self) -> dict:
- return {
+ def as_payload(self) -> NotePayload:
+ return NotePayload(
**super().as_payload(),
- "note_type": self.note_type,
- "subject": self.subject,
- "confidence": self.confidence_dict,
- }
+ note_type=cast(str | None, self.note_type),
+ subject=cast(str | None, self.subject),
+ confidence=self.confidence_dict,
+ )
@property
def display_contents(self) -> dict:
@@ -602,6 +685,19 @@ class Note(SourceItem):
self.as_text(cast(str, self.content), cast(str | None, self.subject))
)
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ return ["text"] # Notes go to the text collection
+
+
+class AgentObservationPayload(SourceItemPayload):
+ session_id: Annotated[str | None, "Session ID for the observation"]
+ observation_type: Annotated[str, "Type of observation"]
+ subject: Annotated[str, "What/who the observation is about"]
+ confidence: Annotated[dict[str, float], "Confidence scores for the observation"]
+ evidence: Annotated[dict | None, "Supporting context, quotes, etc."]
+ agent_model: Annotated[str, "Which AI model made this observation"]
+
class AgentObservation(SourceItem):
"""
@@ -652,18 +748,16 @@ class AgentObservation(SourceItem):
kwargs["modality"] = "observation"
super().__init__(**kwargs)
- def as_payload(self) -> dict:
- payload = {
+ def as_payload(self) -> AgentObservationPayload:
+ return AgentObservationPayload(
**super().as_payload(),
- "observation_type": self.observation_type,
- "subject": self.subject,
- "confidence": self.confidence_dict,
- "evidence": self.evidence,
- "agent_model": self.agent_model,
- }
- if self.session_id is not None:
- payload["session_id"] = str(self.session_id)
- return payload
+ observation_type=cast(str, self.observation_type),
+ subject=cast(str, self.subject),
+ confidence=self.confidence_dict,
+ evidence=cast(dict | None, self.evidence),
+ agent_model=cast(str, self.agent_model),
+ session_id=cast(str | None, self.session_id) and str(self.session_id),
+ )
@property
def all_contradictions(self):
@@ -759,3 +853,7 @@ class AgentObservation(SourceItem):
# ))
return chunks
+
+ @classmethod
+ def get_collections(cls) -> list[str]:
+ return ["semantic", "temporal"]
diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py
index 3df07df..2fbe50a 100644
--- a/src/memory/common/extract.py
+++ b/src/memory/common/extract.py
@@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
return image
-def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
+def doc_to_images(
+ content: bytes | str | pathlib.Path, modality: str = "doc"
+) -> list[DataChunk]:
with as_file(content) as file_path:
with pymupdf.open(file_path) as pdf:
return [
@@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
"height": page.rect.height,
},
mime_type="image/jpeg",
+ modality=modality,
)
for page in pdf.pages()
]
@@ -122,6 +125,7 @@ def extract_text(
content: bytes | str | pathlib.Path,
chunk_size: int | None = None,
metadata: dict[str, Any] = {},
+ modality: str = "text",
) -> list[DataChunk]:
if isinstance(content, pathlib.Path):
content = content.read_text()
@@ -130,7 +134,7 @@ def extract_text(
content = cast(str, content)
chunks = [
- DataChunk(data=[c], modality="text", metadata=metadata)
+ DataChunk(data=[c], modality=modality, metadata=metadata)
for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
]
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
@@ -139,7 +143,7 @@ def extract_text(
DataChunk(
data=[summary],
metadata=merge_metadata(metadata, {"tags": tags}),
- modality="text",
+ modality=modality,
)
)
return chunks
@@ -158,9 +162,7 @@ def extract_data_chunks(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/msword",
]:
- logger.info(f"Extracting content from {content}")
chunks = extract_docx(content)
- logger.info(f"Extracted {len(chunks)} pages from {content}")
elif mime_type.startswith("text/"):
chunks = extract_text(content, chunk_size)
elif mime_type.startswith("image/"):
diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py
index 5e0121e..5e6ad8f 100644
--- a/src/memory/common/qdrant.py
+++ b/src/memory/common/qdrant.py
@@ -224,6 +224,15 @@ def get_collection_info(
return info.model_dump()
+def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
+ """Get the size of each collection."""
+ collections = [i.name for i in client.get_collections().collections]
+ return {
+ collection_name: client.count(collection_name).count # type: ignore
+ for collection_name in collections
+ }
+
+
def batch_ids(
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
) -> Generator[list[str], None, None]:
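A brief usage sketch of the new helper (requires a reachable Qdrant instance; the printed counts are invented). These counts are what `get_metadata_schemas` uses to report collection sizes and skip empty collections:

```python
from memory.common import qdrant

# Assumes a running Qdrant instance with some populated collections.
client = qdrant.get_qdrant_client()
sizes = qdrant.get_collection_sizes(client)
print(sizes)  # e.g. {"mail": 1532, "blog": 204, "book": 87}
```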
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 9de24f1..652faf0 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -6,6 +6,8 @@ load_dotenv()
def boolean_env(key: str, default: bool = False) -> bool:
+ if key not in os.environ:
+ return default
return os.getenv(key, "0").lower() in ("1", "true", "yes")
@@ -130,6 +132,10 @@ else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
+# Search settings
+ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
+ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+
# API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False)
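The early return in `boolean_env` is what lets the new search flags default to on: previously an unset variable fell through to `"0"` and came back False even when the default was True. A quick self-contained check of the new behavior:

```python
import os


def boolean_env(key: str, default: bool = False) -> bool:
    # Same logic as the patched settings.boolean_env.
    if key not in os.environ:
        return default
    return os.getenv(key, "0").lower() in ("1", "true", "yes")


os.environ.pop("ENABLE_BM25_SEARCH", None)
assert boolean_env("ENABLE_BM25_SEARCH", True) is True    # unset -> honors the default
os.environ["ENABLE_BM25_SEARCH"] = "no"
assert boolean_env("ENABLE_BM25_SEARCH", True) is False   # explicitly disabled
```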
diff --git a/tests/memory/api/search/test_search_embeddings.py b/tests/memory/api/search/test_search_embeddings.py
new file mode 100644
index 0000000..ca9a6a8
--- /dev/null
+++ b/tests/memory/api/search/test_search_embeddings.py
@@ -0,0 +1,174 @@
+import pytest
+from memory.api.search.embeddings import merge_range_filter, merge_filters
+
+
+def test_merge_range_filter_new_filter():
+ """Test adding new range filters"""
+ filters = []
+ result = merge_range_filter(filters, "min_size", 100)
+ assert result == [{"key": "size", "range": {"gte": 100}}]
+
+ filters = []
+ result = merge_range_filter(filters, "max_size", 1000)
+ assert result == [{"key": "size", "range": {"lte": 1000}}]
+
+
+def test_merge_range_filter_existing_field():
+ """Test adding to existing field"""
+ filters = [{"key": "size", "range": {"lte": 1000}}]
+ result = merge_range_filter(filters, "min_size", 100)
+ assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]
+
+
+def test_merge_range_filter_override_existing():
+ """Test overriding existing values"""
+ filters = [{"key": "size", "range": {"gte": 100}}]
+ result = merge_range_filter(filters, "min_size", 200)
+ assert result == [{"key": "size", "range": {"gte": 200}}]
+
+
+def test_merge_range_filter_with_other_filters():
+ """Test adding range filter alongside other filters"""
+ filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
+ result = merge_range_filter(filters, "min_size", 100)
+
+ expected = [
+ {"key": "tags", "match": {"any": ["tag1"]}},
+ {"key": "size", "range": {"gte": 100}},
+ ]
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "key,expected_direction,expected_field",
+ [
+ ("min_sent_at", "min", "sent_at"),
+ ("max_sent_at", "max", "sent_at"),
+ ("min_published", "min", "published"),
+ ("max_published", "max", "published"),
+ ("min_size", "min", "size"),
+ ("max_size", "max", "size"),
+ ],
+)
+def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
+ """Test that field names are correctly extracted from keys"""
+ filters = []
+ result = merge_range_filter(filters, key, 100)
+
+ assert len(result) == 1
+ assert result[0]["key"] == expected_field
+ range_key = "gte" if expected_direction == "min" else "lte"
+ assert result[0]["range"][range_key] == 100
+
+
+@pytest.mark.parametrize(
+ "filter_key,filter_value",
+ [
+ ("tags", ["tag1", "tag2"]),
+ ("recipients", ["user1", "user2"]),
+ ("observation_types", ["belief", "preference"]),
+ ("authors", ["author1"]),
+ ],
+)
+def test_merge_filters_list_filters(filter_key, filter_value):
+ """Test list filters that use match any"""
+ filters = []
+ result = merge_filters(filters, filter_key, filter_value)
+ expected = [{"key": filter_key, "match": {"any": filter_value}}]
+ assert result == expected
+
+
+def test_merge_filters_min_confidences():
+ """Test min_confidences filter creates multiple range conditions"""
+ filters = []
+ confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9}
+ result = merge_filters(filters, "min_confidences", confidences)
+
+ expected = [
+ {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
+ {"key": "confidence.source_reliability", "range": {"gte": 0.9}},
+ ]
+ assert result == expected
+
+
+def test_merge_filters_source_ids():
+ """Test source_ids filter maps to id field"""
+ filters = []
+ result = merge_filters(filters, "source_ids", ["id1", "id2"])
+ expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}]
+ assert result == expected
+
+
+def test_merge_filters_range_delegation():
+ """Test range filters are properly delegated to merge_range_filter"""
+ filters = []
+ result = merge_filters(filters, "min_size", 100)
+
+ assert len(result) == 1
+ assert "range" in result[0]
+ assert result[0]["range"]["gte"] == 100
+
+
+def test_merge_filters_combined_range():
+ """Test that min/max range pairs merge into single filter"""
+ filters = []
+ filters = merge_filters(filters, "min_size", 100)
+ filters = merge_filters(filters, "max_size", 1000)
+
+ size_filters = [f for f in filters if f["key"] == "size"]
+ assert len(size_filters) == 1
+ assert size_filters[0]["range"]["gte"] == 100
+ assert size_filters[0]["range"]["lte"] == 1000
+
+
+def test_merge_filters_preserves_existing():
+ """Test that existing filters are preserved when adding new ones"""
+ existing_filters = [{"key": "existing", "match": "value"}]
+ result = merge_filters(existing_filters, "tags", ["new_tag"])
+
+ assert len(result) == 2
+ assert {"key": "existing", "match": "value"} in result
+ assert {"key": "tags", "match": {"any": ["new_tag"]}} in result
+
+
+def test_merge_filters_realistic_combination():
+ """Test a realistic filter combination for knowledge base search"""
+ filters = []
+
+ # Add typical knowledge base filters
+ filters = merge_filters(filters, "tags", ["important", "work"])
+ filters = merge_filters(filters, "min_published", "2023-01-01")
+ filters = merge_filters(filters, "max_size", 1000000) # 1MB max
+ filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})
+
+ assert len(filters) == 4
+
+ # Check each filter type
+ tag_filter = next(f for f in filters if f["key"] == "tags")
+ assert tag_filter["match"]["any"] == ["important", "work"]
+
+ published_filter = next(f for f in filters if f["key"] == "published")
+ assert published_filter["range"]["gte"] == "2023-01-01"
+
+ size_filter = next(f for f in filters if f["key"] == "size")
+ assert size_filter["range"]["lte"] == 1000000
+
+ confidence_filter = next(
+ f for f in filters if f["key"] == "confidence.observation_accuracy"
+ )
+ assert confidence_filter["range"]["gte"] == 0.8
+
+
+def test_merge_filters_unknown_key():
+ """Test fallback behavior for unknown filter keys"""
+ filters = []
+ result = merge_filters(filters, "unknown_field", "unknown_value")
+ expected = [{"key": "unknown_field", "match": "unknown_value"}]
+ assert result == expected
+
+
+def test_merge_filters_empty_min_confidences():
+ """Test min_confidences with empty dict does nothing"""
+ filters = []
+ result = merge_filters(filters, "min_confidences", {})
+ assert result == []