diff --git a/frontend/src/App.css b/frontend/src/App.css index 32ff733..193eabe 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -285,6 +285,7 @@ body { display: flex; gap: 1rem; align-items: center; + margin-bottom: 1.5rem; } .search-input { @@ -323,6 +324,34 @@ body { cursor: not-allowed; } +.search-options { + border-top: 1px solid #e2e8f0; + padding-top: 1.5rem; + display: grid; + gap: 1.5rem; +} + +.search-option { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.search-option label { + font-weight: 500; + color: #4a5568; + font-size: 0.9rem; + display: flex; + align-items: center; + gap: 0.5rem; +} + +.search-option input[type="checkbox"] { + margin-right: 0.5rem; + transform: scale(1.1); + accent-color: #667eea; +} + /* Search Results */ .search-results { margin-top: 2rem; @@ -429,6 +458,11 @@ body { transition: background-color 0.2s; } +.tag.selected { + background: #667eea; + color: white; +} + .tag:hover { background: #e2e8f0; } @@ -495,6 +529,14 @@ body { padding: 1.5rem; } + .search-options { + gap: 1rem; + } + + .limit-input { + width: 80px; + } + .search-result-card { padding: 1rem; } @@ -689,4 +731,194 @@ body { .markdown-preview a:hover { color: #2b6cb0; +} + +/* Dynamic filters styles */ +.modality-filters { + display: flex; + flex-direction: column; + gap: 1rem; + margin-top: 0.5rem; +} + +.modality-filter-group { + border: 1px solid #e2e8f0; + border-radius: 8px; + background: white; +} + +.modality-filter-title { + padding: 0.75rem 1rem; + font-weight: 600; + color: #4a5568; + background: #f8fafc; + border-radius: 7px 7px 0 0; + cursor: pointer; + user-select: none; + transition: background-color 0.2s; +} + +.modality-filter-title:hover { + background: #edf2f7; +} + +.modality-filter-group[open] .modality-filter-title { + border-bottom: 1px solid #e2e8f0; + border-radius: 7px 7px 0 0; +} + +.filters-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: 1rem; + padding: 1rem; +} + +.filter-field { + display: flex; + flex-direction: column; + gap: 0.25rem; +} + +.filter-label { + font-size: 0.875rem; + font-weight: 500; + color: #4a5568; + margin-bottom: 0.25rem; +} + +.filter-input { + padding: 0.5rem; + border: 1px solid #e2e8f0; + border-radius: 4px; + font-size: 0.875rem; + background: white; + transition: border-color 0.2s; +} + +.filter-input:focus { + outline: none; + border-color: #667eea; + box-shadow: 0 0 0 1px #667eea; +} + +/* Selectable tags controls */ +.selectable-tags-details { + border: 1px solid #e2e8f0; + border-radius: 8px; + background: white; +} + +.selectable-tags-summary { + padding: 0.75rem 1rem; + font-weight: 600; + color: #4a5568; + cursor: pointer; + user-select: none; + transition: background-color 0.2s; + border-radius: 7px; +} + +.selectable-tags-summary:hover { + background: #f8fafc; +} + +.selectable-tags-details[open] .selectable-tags-summary { + border-bottom: 1px solid #e2e8f0; + border-radius: 7px 7px 0 0; + background: #f8fafc; +} + +.tag-controls { + display: flex; + align-items: center; + gap: 0.5rem; + padding: 0.75rem 1rem; + border-bottom: 1px solid #e2e8f0; + background: #fafbfc; +} + +.tag-control-btn { + background: #f7fafc; + border: 1px solid #e2e8f0; + border-radius: 4px; + padding: 0.25rem 0.5rem; + font-size: 0.75rem; + color: #4a5568; + cursor: pointer; + transition: all 0.2s; +} + +.tag-control-btn:hover:not(:disabled) { + background: #edf2f7; + border-color: #cbd5e0; +} + +.tag-control-btn:disabled { + opacity: 0.5; + cursor: not-allowed; +} + 
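The `[open]` attribute selectors added above (`.modality-filter-group[open]`, `.selectable-tags-details[open]`) only match native disclosure elements, so these rules appear to assume `<details>`/`<summary>` markup. A minimal sketch of that assumed shape — the class names come from the CSS, but the component name and props are illustrative, not the actual JSX (which lives in `DynamicFilters.tsx` further down):

```tsx
import type { ReactNode } from 'react'

// Sketch of the markup the .modality-filter-group rules appear to style.
// Component name and props are assumptions for illustration only.
const ModalityFilterGroup = ({ title, children }: { title: string; children: ReactNode }) => (
  <details className="modality-filter-group" open>
    {/* .modality-filter-title gets the hover background and, when [open],
        the bottom border defined in the rules above */}
    <summary className="modality-filter-title">{title}</summary>
    <div className="filters-grid">{children}</div>
  </details>
)

export default ModalityFilterGroup
```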
+.tag-count { + font-size: 0.75rem; + color: #718096; + font-weight: 500; +} + +/* Tag search controls */ +.tag-search-controls { + display: flex; + align-items: center; + gap: 1rem; + padding: 0.75rem 1rem; + background: #f8fafc; + border-bottom: 1px solid #e2e8f0; +} + +.tag-search-input { + flex: 1; + padding: 0.375rem 0.5rem; + border: 1px solid #e2e8f0; + border-radius: 4px; + font-size: 0.875rem; + background: white; + transition: border-color 0.2s; +} + +.tag-search-input:focus { + outline: none; + border-color: #667eea; + box-shadow: 0 0 0 1px #667eea; +} + + + +.filtered-count { + font-size: 0.75rem; + color: #718096; + font-weight: 500; + white-space: nowrap; +} + +.tags-display-area { + padding: 1rem; +} + +@media (max-width: 768px) { + .filters-grid { + grid-template-columns: 1fr; + gap: 0.5rem; + padding: 0.75rem; + } + + .modality-filter-title { + padding: 0.5rem 0.75rem; + font-size: 0.9rem; + } + + .tag-search-controls { + flex-direction: column; + align-items: stretch; + gap: 0.5rem; + } } \ No newline at end of file diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index e7df131..3b28c2b 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -2,9 +2,9 @@ import { useEffect } from 'react' import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom' import './App.css' -import { useAuth } from './hooks/useAuth' -import { useOAuth } from './hooks/useOAuth' -import { Loading, LoginPrompt, AuthError, Dashboard, Search } from './components' +import { useAuth } from '@/hooks/useAuth' +import { useOAuth } from '@/hooks/useOAuth' +import { Loading, LoginPrompt, AuthError, Dashboard, Search } from '@/components' // AuthWrapper handles redirects based on auth state const AuthWrapper = () => { @@ -31,7 +31,10 @@ const AuthWrapper = () => { // Handle redirects based on auth state changes useEffect(() => { if (!isLoading) { - if (isAuthenticated) { + if (location.pathname === '/ui/logout') { + logout() + navigate('/ui/login', { replace: true }) + } else if (isAuthenticated) { // If authenticated and on login page, redirect to dashboard if (location.pathname === '/ui/login' || location.pathname === '/ui') { navigate('/ui/dashboard', { replace: true }) diff --git a/frontend/src/components/Dashboard.jsx b/frontend/src/components/Dashboard.jsx index 3b2e55d..6a7d938 100644 --- a/frontend/src/components/Dashboard.jsx +++ b/frontend/src/components/Dashboard.jsx @@ -1,5 +1,5 @@ import { Link } from 'react-router-dom' -import { useMCP } from '../hooks/useMCP' +import { useMCP } from '@/hooks/useMCP' const Dashboard = ({ onLogout }) => { const { listNotes } = useMCP() diff --git a/frontend/src/components/Loading.tsx b/frontend/src/components/Loading.tsx index 76e8e2d..0222344 100644 --- a/frontend/src/components/Loading.tsx +++ b/frontend/src/components/Loading.tsx @@ -1,6 +1,4 @@ -import React from 'react' - -const Loading = ({ message = "Loading..." }) => { +const Loading = ({ message = "Loading..." }: { message?: string }) => { return (

{message}

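The App.jsx hunk above adds a `/ui/logout` branch to the auth-redirect effect. Pulled out on its own, the flow reads roughly like the sketch below; it assumes `isLoading`, `isAuthenticated`, and `logout` come from `useAuth` as elsewhere in that file:

```tsx
import { useEffect } from 'react'
import { useNavigate, useLocation } from 'react-router-dom'

// Standalone sketch of the new redirect logic; the useAuth wiring is assumed.
const useAuthRedirects = (isLoading: boolean, isAuthenticated: boolean, logout: () => void) => {
  const navigate = useNavigate()
  const location = useLocation()

  useEffect(() => {
    if (isLoading) return
    if (location.pathname === '/ui/logout') {
      // Visiting /ui/logout clears the session, then lands on the login page
      logout()
      navigate('/ui/login', { replace: true })
    } else if (isAuthenticated && ['/ui', '/ui/login'].includes(location.pathname)) {
      // Already signed in: skip the login page
      navigate('/ui/dashboard', { replace: true })
    }
  }, [isLoading, isAuthenticated, location.pathname, logout, navigate])
}
```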
diff --git a/frontend/src/components/auth/AuthError.tsx b/frontend/src/components/auth/AuthError.tsx index ab46755..65d201a 100644 --- a/frontend/src/components/auth/AuthError.tsx +++ b/frontend/src/components/auth/AuthError.tsx @@ -1,6 +1,4 @@ -import React from 'react' - -const AuthError = ({ error, onRetry }) => { +const AuthError = ({ error, onRetry }: { error: string, onRetry: () => void }) => { return (

Authentication Error

diff --git a/frontend/src/components/auth/LoginPrompt.tsx b/frontend/src/components/auth/LoginPrompt.tsx index f3d0f1e..fbd2020 100644 --- a/frontend/src/components/auth/LoginPrompt.tsx +++ b/frontend/src/components/auth/LoginPrompt.tsx @@ -1,6 +1,4 @@ -import React from 'react' - -const LoginPrompt = ({ onLogin }) => { +const LoginPrompt = ({ onLogin }: { onLogin: () => void }) => { return (

Memory App

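The three component diffs above (Loading, AuthError, LoginPrompt) apply the same conversion: drop the `React` default import, which the new tsconfig's `"jsx": "react-jsx"` automatic runtime makes redundant, and type the props inline. The shape of the change, with a placeholder body since the actual JSX is unchanged context in these hunks:

```tsx
// Before: untyped props plus a now-unneeded default import
// import React from 'react'
// const Loading = ({ message = "Loading..." }) => { ... }

// After: no React import under the automatic JSX runtime, props typed inline.
// The element below is a placeholder, not the component's real markup.
const Loading = ({ message = 'Loading...' }: { message?: string }) => (
  <p className="loading">{message}</p>
)

export default Loading
```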
diff --git a/frontend/src/components/index.js b/frontend/src/components/index.js index b21dc34..546fb01 100644 --- a/frontend/src/components/index.js +++ b/frontend/src/components/index.js @@ -1,5 +1,5 @@ export { default as Loading } from './Loading' export { default as Dashboard } from './Dashboard' -export { default as Search } from './Search' +export { default as Search } from './search' export { default as LoginPrompt } from './auth/LoginPrompt' export { default as AuthError } from './auth/AuthError' \ No newline at end of file diff --git a/frontend/src/components/search/DynamicFilters.tsx b/frontend/src/components/search/DynamicFilters.tsx new file mode 100644 index 0000000..19fb616 --- /dev/null +++ b/frontend/src/components/search/DynamicFilters.tsx @@ -0,0 +1,155 @@ +import { FilterInput } from './FilterInput' +import { CollectionMetadata, SchemaArg } from '@/types/mcp' + +// Pure helper functions for schema processing +const formatFieldLabel = (field: string): string => + field.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase()) + +const shouldSkipField = (fieldName: string): boolean => + ['tags'].includes(fieldName) + +const isCommonField = (fieldName: string): boolean => + ['size', 'filename', 'content_type', 'mail_message_id', 'sent_at', 'created_at', 'source_id'].includes(fieldName) + +const createMinMaxFields = (fieldName: string, fieldConfig: any): [string, any][] => [ + [`min_${fieldName}`, { ...fieldConfig, description: `Min ${fieldConfig.description}` }], + [`max_${fieldName}`, { ...fieldConfig, description: `Max ${fieldConfig.description}` }] +] + +const createSizeFields = (): [string, any][] => [ + ['min_size', { type: 'int', description: 'Minimum size in bytes' }], + ['max_size', { type: 'int', description: 'Maximum size in bytes' }] +] + +const extractSchemaFields = (schema: Record, includeCommon = true): [string, SchemaArg][] => + Object.entries(schema) + .filter(([fieldName]) => !shouldSkipField(fieldName)) + .filter(([fieldName]) => includeCommon || !isCommonField(fieldName)) + .flatMap(([fieldName, fieldConfig]) => + ['sent_at', 'published', 'created_at'].includes(fieldName) + ? 
createMinMaxFields(fieldName, fieldConfig) + : [[fieldName, fieldConfig] as [string, SchemaArg]] + ) + +const getCommonFields = (schemas: Record, selectedModalities: string[]): [string, SchemaArg][] => { + const commonFieldsMap = new Map() + + // Manually add created_at fields even if not in schema + createMinMaxFields('created_at', { type: 'datetime', description: 'Creation date' }).forEach(([field, config]) => { + commonFieldsMap.set(field, config) + }) + + selectedModalities.forEach(modality => { + const schema = schemas[modality].schema + if (!schema) return + + Object.entries(schema).forEach(([fieldName, fieldConfig]) => { + if (isCommonField(fieldName)) { + if (['sent_at', 'created_at'].includes(fieldName)) { + createMinMaxFields(fieldName, fieldConfig).forEach(([field, config]) => { + commonFieldsMap.set(field, config) + }) + } else if (fieldName === 'size') { + createSizeFields().forEach(([field, config]) => { + commonFieldsMap.set(field, config) + }) + } else { + commonFieldsMap.set(fieldName, fieldConfig) + } + } + }) + }) + + return Array.from(commonFieldsMap.entries()) +} + +const getModalityFields = (schemas: Record, selectedModalities: string[]): Record => { + return selectedModalities.reduce((acc, modality) => { + const schema = schemas[modality].schema + if (!schema) return acc + + const schemaFields = extractSchemaFields(schema, false) // Exclude common fields + + if (schemaFields.length > 0) { + acc[modality] = schemaFields + } + + return acc + }, {} as Record) +} + +interface DynamicFiltersProps { + schemas: Record + selectedModalities: string[] + filters: Record + onFilterChange: (field: string, value: SchemaArg) => void +} + +export const DynamicFilters = ({ + schemas, + selectedModalities, + filters, + onFilterChange +}: DynamicFiltersProps) => { + const commonFields = getCommonFields(schemas, selectedModalities) + const modalityFields = getModalityFields(schemas, selectedModalities) + + if (commonFields.length === 0 && Object.keys(modalityFields).length === 0) { + return null + } + + return ( +
+ +
+ {/* Common/Document Properties Section */} + {commonFields.length > 0 && ( +
+ + Document Properties + +
+ {commonFields.map(([field, fieldConfig]: [string, SchemaArg]) => ( +
+ + +
+ ))} +
+
+ )} + + {/* Modality-specific sections */} + {Object.entries(modalityFields).map(([modality, fields]) => ( +
+ + {formatFieldLabel(modality)} Specific + +
+ {fields.map(([field, fieldConfig]: [string, SchemaArg]) => ( +
+ + +
+ ))} +
+
+ ))} +
+
+ ) +} \ No newline at end of file diff --git a/frontend/src/components/search/FilterInput.tsx b/frontend/src/components/search/FilterInput.tsx new file mode 100644 index 0000000..0fb48d4 --- /dev/null +++ b/frontend/src/components/search/FilterInput.tsx @@ -0,0 +1,53 @@ +// Pure helper functions +const isDateField = (field: string): boolean => + field.includes('sent_at') || field.includes('published') || field.includes('created_at') + +const isNumberField = (field: string, fieldConfig: any): boolean => + fieldConfig.type?.includes('int') || field.includes('size') + +const parseNumberValue = (value: string): number | null => + value ? parseInt(value) : null + +interface FilterInputProps { + field: string + fieldConfig: any + value: any + onChange: (field: string, value: any) => void +} + +export const FilterInput = ({ field, fieldConfig, value, onChange }: FilterInputProps) => { + const inputProps = { + value: value || '', + className: "filter-input" + } + + if (isNumberField(field, fieldConfig)) { + return ( + onChange(field, parseNumberValue(e.target.value))} + placeholder={fieldConfig.description} + /> + ) + } + + if (isDateField(field)) { + return ( + onChange(field, e.target.value || null)} + /> + ) + } + + return ( + onChange(field, e.target.value || null)} + placeholder={fieldConfig.description} + /> + ) +} \ No newline at end of file diff --git a/frontend/src/components/search/Search.tsx b/frontend/src/components/search/Search.tsx new file mode 100644 index 0000000..02302ca --- /dev/null +++ b/frontend/src/components/search/Search.tsx @@ -0,0 +1,69 @@ +import { useState } from 'react' +import { useNavigate } from 'react-router-dom' +import { useMCP } from '@/hooks/useMCP' +import Loading from '@/components/Loading' +import { SearchResult } from './results' +import SearchForm, { SearchParams } from './SearchForm' + +const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => { + if (isLoading) { + return + } + return ( +
+ {results.length > 0 && ( +
+ Found {results.length} result{results.length !== 1 ? 's' : ''} +
+ )} + + {results.map((result, index) => )} + + {results.length === 0 && ( +
+ No results found +
+ )} +
+ ) +} + + +const Search = () => { + const navigate = useNavigate() + const [results, setResults] = useState([]) + const [isLoading, setIsLoading] = useState(false) + const { searchKnowledgeBase } = useMCP() + + const handleSearch = async (params: SearchParams) => { + if (!params.query.trim()) return + + setIsLoading(true) + try { + const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities) + setResults(searchResults || []) + } catch (error) { + console.error('Search error:', error) + setResults([]) + } finally { + setIsLoading(false) + } + } + + return ( +
+
+ +

🔍 Search Knowledge Base

+
+ + + +
+ ) +} + +export default Search \ No newline at end of file diff --git a/frontend/src/components/search/SearchForm.tsx b/frontend/src/components/search/SearchForm.tsx new file mode 100644 index 0000000..717fa78 --- /dev/null +++ b/frontend/src/components/search/SearchForm.tsx @@ -0,0 +1,151 @@ +import { useMCP } from '@/hooks/useMCP' +import { useEffect, useState } from 'react' +import { DynamicFilters } from './DynamicFilters' +import { SelectableTags } from './SelectableTags' +import { CollectionMetadata } from '@/types/mcp' + +type Filter = { + tags?: string[] + source_ids?: string[] + [key: string]: any +} + +export interface SearchParams { + query: string + previews: boolean + modalities: string[] + filters: Filter + limit: number +} + +interface SearchFormProps { + isLoading: boolean + onSearch: (params: SearchParams) => void +} + + + +// Pure helper functions for SearchForm +const createFlags = (items: string[], defaultValue = false): Record => + items.reduce((acc, item) => ({ ...acc, [item]: defaultValue }), {}) + +const getSelectedItems = (items: Record): string[] => + Object.entries(items).filter(([_, selected]) => selected).map(([key]) => key) + +const cleanFilters = (filters: Record): Record => + Object.entries(filters) + .filter(([_, value]) => value !== null && value !== '' && value !== undefined) + .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}) + +export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => { + const [query, setQuery] = useState('') + const [previews, setPreviews] = useState(false) + const [modalities, setModalities] = useState>({}) + const [schemas, setSchemas] = useState>({}) + const [tags, setTags] = useState>({}) + const [dynamicFilters, setDynamicFilters] = useState>({}) + const [limit, setLimit] = useState(10) + const { getMetadataSchemas, getTags } = useMCP() + + useEffect(() => { + const setupFilters = async () => { + const [schemas, tags] = await Promise.all([ + getMetadataSchemas(), + getTags() + ]) + setSchemas(schemas) + setModalities(createFlags(Object.keys(schemas), true)) + setTags(createFlags(tags)) + } + setupFilters() + }, [getMetadataSchemas, getTags]) + + const handleFilterChange = (field: string, value: any) => + setDynamicFilters(prev => ({ ...prev, [field]: value })) + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault() + + onSearch({ + query, + previews, + modalities: getSelectedItems(modalities), + filters: { + tags: getSelectedItems(tags), + ...cleanFilters(dynamicFilters) + }, + limit + }) + } + + return ( +
+
+ setQuery(e.target.value)} + placeholder="Search your knowledge base..." + className="search-input" + required + /> + +
+ +
+
+ +
+ + setModalities({ ...modalities, [tag]: selected })} + onBatchUpdate={(updates) => setModalities(updates)} + /> + + setTags({ ...tags, [tag]: selected })} + onBatchUpdate={(updates) => setTags(updates)} + searchable={true} + /> + + + +
+ +
+
+
+ ) +} + +export default SearchForm \ No newline at end of file diff --git a/frontend/src/components/search/SelectableTags.tsx b/frontend/src/components/search/SelectableTags.tsx new file mode 100644 index 0000000..dddcddc --- /dev/null +++ b/frontend/src/components/search/SelectableTags.tsx @@ -0,0 +1,136 @@ +interface SelectableTagProps { + tag: string + selected: boolean + onSelect: (tag: string, selected: boolean) => void +} + +const SelectableTag = ({ tag, selected, onSelect }: SelectableTagProps) => { + return ( + onSelect(tag, !selected)} + > + {tag} + + ) +} + +import { useState } from 'react' + +interface SelectableTagsProps { + title: string + className: string + tags: Record + onSelect: (tag: string, selected: boolean) => void + onBatchUpdate?: (updates: Record) => void + searchable?: boolean +} + +export const SelectableTags = ({ title, className, tags, onSelect, onBatchUpdate, searchable = false }: SelectableTagsProps) => { + const [searchTerm, setSearchTerm] = useState('') + const handleSelectAll = () => { + if (onBatchUpdate) { + const updates = Object.keys(tags).reduce((acc, tag) => { + acc[tag] = true + return acc + }, {} as Record) + onBatchUpdate(updates) + } else { + // Fallback to individual updates (though this won't work well) + Object.keys(tags).forEach(tag => { + if (!tags[tag]) { + onSelect(tag, true) + } + }) + } + } + + const handleDeselectAll = () => { + if (onBatchUpdate) { + const updates = Object.keys(tags).reduce((acc, tag) => { + acc[tag] = false + return acc + }, {} as Record) + onBatchUpdate(updates) + } else { + // Fallback to individual updates (though this won't work well) + Object.keys(tags).forEach(tag => { + if (tags[tag]) { + onSelect(tag, false) + } + }) + } + } + + // Filter tags based on search term + const filteredTags = Object.entries(tags).filter(([tag, selected]) => { + return !searchTerm || tag.toLowerCase().includes(searchTerm.toLowerCase()) + }) + + const selectedCount = Object.values(tags).filter(Boolean).length + const totalCount = Object.keys(tags).length + const filteredSelectedCount = filteredTags.filter(([_, selected]) => selected).length + const filteredTotalCount = filteredTags.length + const allSelected = selectedCount === totalCount + const noneSelected = selectedCount === 0 + + return ( +
+
+ + {title} ({selectedCount} selected) + + +
+ + + + ({selectedCount}/{totalCount}) + +
+ + {searchable && ( +
+ setSearchTerm(e.target.value)} + className="tag-search-input" + /> + {searchTerm && ( + + Showing {filteredSelectedCount}/{filteredTotalCount} + + )} +
+ )} + +
+ {filteredTags.map(([tag, selected]: [string, boolean]) => ( + + ))} +
+
+
+ ) +} \ No newline at end of file diff --git a/frontend/src/components/search/index.js b/frontend/src/components/search/index.js new file mode 100644 index 0000000..9b16ddd --- /dev/null +++ b/frontend/src/components/search/index.js @@ -0,0 +1,3 @@ +import Search from './Search' + +export default Search \ No newline at end of file diff --git a/frontend/src/components/Search.tsx b/frontend/src/components/search/results.tsx similarity index 58% rename from frontend/src/components/Search.tsx rename to frontend/src/components/search/results.tsx index 13589dc..34b1634 100644 --- a/frontend/src/components/Search.tsx +++ b/frontend/src/components/search/results.tsx @@ -1,10 +1,8 @@ -import React, { useState, useEffect } from 'react' -import { useNavigate } from 'react-router-dom' +import { useState, useEffect } from 'react' import ReactMarkdown from 'react-markdown' -import { useMCP } from '../hooks/useMCP' -import Loading from './Loading' +import { useMCP } from '@/hooks/useMCP' -type SearchItem = { +export type SearchItem = { filename: string content: string chunks: any[] @@ -13,7 +11,7 @@ type SearchItem = { metadata: any } -const Tag = ({ tags }: { tags: string[] }) => { +export const Tag = ({ tags }: { tags: string[] }) => { return (
{tags?.map((tag: string, index: number) => ( @@ -23,11 +21,12 @@ const Tag = ({ tags }: { tags: string[] }) => { ) } -const TextResult = ({ filename, content, chunks, tags }: SearchItem) => { +export const TextResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => { return (

{filename || 'Untitled'}

+

{content || 'No content available'}

{chunks && chunks.length > 0 && (
@@ -44,7 +43,7 @@ const TextResult = ({ filename, content, chunks, tags }: SearchItem) => { ) } -const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => { +export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => { return (

{filename || 'Untitled'}

@@ -70,7 +69,7 @@ const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchIte ) } -const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => { +export const ImageResult = ({ filename, tags, metadata }: SearchItem) => { const title = metadata?.title || filename || 'Untitled' const { fetchFile } = useMCP() const [mime_type, setMimeType] = useState() @@ -95,7 +94,7 @@ const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => { ) } -const Metadata = ({ metadata }: { metadata: any }) => { +export const Metadata = ({ metadata }: { metadata: any }) => { if (!metadata) return null return (
@@ -108,7 +107,7 @@ const Metadata = ({ metadata }: { metadata: any }) => { ) } -const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => { +export const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => { return (

{filename || 'Untitled'}

@@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => { ) } -const EmailResult = ({ content, tags, metadata }: SearchItem) => { +export const EmailResult = ({ content, tags, metadata }: SearchItem) => { return (

{metadata?.title || metadata?.subject || 'Untitled'}

@@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => { ) } -const SearchResult = ({ result }: { result: SearchItem }) => { +export const SearchResult = ({ result }: { result: SearchItem }) => { if (result.mime_type.startsWith('image/')) { return } @@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => { return null } -const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => { - if (isLoading) { - return - } - return ( -
- {results.length > 0 && ( -
- Found {results.length} result{results.length !== 1 ? 's' : ''} -
- )} - - {results.map((result, index) => )} - - {results.length === 0 && ( -
- No results found -
- )} -
- ) -} - -const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => { - const [query, setQuery] = useState('') - return ( -
{ - e.preventDefault() - onSearch(query) - }} className="search-form"> -
- setQuery(e.target.value)} - placeholder="Search your knowledge base..." - className="search-input" - /> - -
-
- ) -} - -const Search = () => { - const navigate = useNavigate() - const [results, setResults] = useState([]) - const [isLoading, setIsLoading] = useState(false) - const { searchKnowledgeBase } = useMCP() - - const handleSearch = async (query: string) => { - if (!query.trim()) return - - setIsLoading(true) - try { - const searchResults = await searchKnowledgeBase(query) - setResults(searchResults || []) - } catch (error) { - console.error('Search error:', error) - setResults([]) - } finally { - setIsLoading(false) - } - } - - return ( -
-
- -

🔍 Search Knowledge Base

-
- - - -
- ) -} - -export default Search \ No newline at end of file +export default SearchResult \ No newline at end of file diff --git a/frontend/src/hooks/useAuth.ts b/frontend/src/hooks/useAuth.ts index 697b930..557689f 100644 --- a/frontend/src/hooks/useAuth.ts +++ b/frontend/src/hooks/useAuth.ts @@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000' const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id' // Cookie utilities -const getCookie = (name) => { +const getCookie = (name: string) => { const value = `; ${document.cookie}` const parts = value.split(`; ${name}=`) if (parts.length === 2) return parts.pop().split(';').shift() return null } -const setCookie = (name, value, days = 30) => { +const setCookie = (name: string, value: string, days = 30) => { const expires = new Date() expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000) document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax` } -const deleteCookie = (name) => { +const deleteCookie = (name: string) => { document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/` } @@ -68,6 +68,7 @@ export const useAuth = () => { deleteCookie('access_token') deleteCookie('refresh_token') deleteCookie(SESSION_COOKIE_NAME) + localStorage.removeItem('oauth_client_id') setIsAuthenticated(false) }, []) @@ -110,7 +111,7 @@ export const useAuth = () => { }, [logout]) // Make authenticated API calls with automatic token refresh - const apiCall = useCallback(async (endpoint, options = {}) => { + const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => { let accessToken = getCookie('access_token') if (!accessToken) { @@ -122,7 +123,7 @@ export const useAuth = () => { 'Content-Type': 'application/json', } - const requestOptions = { + const requestOptions: RequestInit & { headers: Record } = { ...options, headers: { ...defaultHeaders, ...options.headers }, } diff --git a/frontend/src/hooks/useMCP.ts b/frontend/src/hooks/useMCP.ts index ca2c150..304cc96 100644 --- a/frontend/src/hooks/useMCP.ts +++ b/frontend/src/hooks/useMCP.ts @@ -1,5 +1,5 @@ import { useEffect, useCallback } from 'react' -import { useAuth } from './useAuth' +import { useAuth } from '@/hooks/useAuth' const parseServerSentEvents = async (response: Response): Promise => { const reader = response.body?.getReader() @@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise => { } export const useMCP = () => { - const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth() + const { apiCall, checkAuth } = useAuth() - const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => { - const response = await apiCall(`/mcp${path}`, { + const mcpCall = useCallback(async (method: string, params: any = {}) => { + const response = await apiCall(`/mcp/${method}`, { method: 'POST', headers: { 'Accept': 'application/json, text/event-stream', @@ -118,22 +118,46 @@ export const useMCP = () => { if (resp?.result?.isError) { throw new Error(resp?.result?.content[0].text) } - return resp?.result?.content.map((item: any) => JSON.parse(item.text)) + return resp?.result?.content.map((item: any) => { + try { + return JSON.parse(item.text) + } catch (e) { + return item.text + } + }) }, [apiCall]) const listNotes = useCallback(async (path: string = "/") => { - return await mcpCall('/note_files', 'note_files', { path }) + return await mcpCall('note_files', { path }) }, [mcpCall]) const fetchFile = 
useCallback(async (filename: string) => { - return await mcpCall('/fetch_file', 'fetch_file', { filename }) + return await mcpCall('fetch_file', { filename }) }, [mcpCall]) - const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => { - return await mcpCall('/search_knowledge_base', 'search_knowledge_base', { + const getTags = useCallback(async () => { + return await mcpCall('get_all_tags') + }, [mcpCall]) + + const getSubjects = useCallback(async () => { + return await mcpCall('get_all_subjects') + }, [mcpCall]) + + const getObservationTypes = useCallback(async () => { + return await mcpCall('get_all_observation_types') + }, [mcpCall]) + + const getMetadataSchemas = useCallback(async () => { + return (await mcpCall('get_metadata_schemas'))[0] + }, [mcpCall]) + + const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record = {}, modalities: string[] = []) => { + return await mcpCall('search_knowledge_base', { query, + filters, + modalities, previews, - limit + limit, }) }, [mcpCall]) @@ -146,5 +170,9 @@ export const useMCP = () => { fetchFile, listNotes, searchKnowledgeBase, + getTags, + getSubjects, + getObservationTypes, + getMetadataSchemas, } } \ No newline at end of file diff --git a/frontend/src/hooks/useOAuth.ts b/frontend/src/hooks/useOAuth.ts index 8f215ff..b5b4491 100644 --- a/frontend/src/hooks/useOAuth.ts +++ b/frontend/src/hooks/useOAuth.ts @@ -14,7 +14,7 @@ const generateCodeVerifier = () => { .replace(/=/g, '') } -const generateCodeChallenge = async (verifier) => { +const generateCodeChallenge = async (verifier: string) => { const data = new TextEncoder().encode(verifier) const digest = await crypto.subtle.digest('SHA-256', data) return btoa(String.fromCharCode(...new Uint8Array(digest))) @@ -33,7 +33,7 @@ const generateState = () => { } // Storage utilities -const setCookie = (name, value, days = 30) => { +const setCookie = (name: string, value: string, days = 30) => { const expires = new Date() expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000) document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax` diff --git a/frontend/src/main.jsx b/frontend/src/main.jsx index 3d9da8a..f9bdd5f 100644 --- a/frontend/src/main.jsx +++ b/frontend/src/main.jsx @@ -1,6 +1,6 @@ import { StrictMode } from 'react' import { createRoot } from 'react-dom/client' -import App from './App.jsx' +import App from '@/App.jsx' createRoot(document.getElementById('root')).render( diff --git a/frontend/src/types/mcp.tsx b/frontend/src/types/mcp.tsx new file mode 100644 index 0000000..8344a5d --- /dev/null +++ b/frontend/src/types/mcp.tsx @@ -0,0 +1,9 @@ +export type CollectionMetadata = { + schema: Record + size: number +} + +export type SchemaArg = { + type: string + description: string +} \ No newline at end of file diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json new file mode 100644 index 0000000..8503bc8 --- /dev/null +++ b/frontend/tsconfig.json @@ -0,0 +1,40 @@ +{ + "compilerOptions": { + "target": "ES2020", + "useDefineForClassFields": true, + "lib": [ + "ES2020", + "DOM", + "DOM.Iterable" + ], + "module": "ESNext", + "skipLibCheck": true, + /* Bundler mode */ + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + "jsx": "react-jsx", + /* Linting */ + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + 
"noFallthroughCasesInSwitch": true, + /* Absolute imports */ + "baseUrl": ".", + "paths": { + "@/*": [ + "src/*" + ] + } + }, + "include": [ + "src" + ], + "references": [ + { + "path": "./tsconfig.node.json" + } + ] +} \ No newline at end of file diff --git a/frontend/tsconfig.node.json b/frontend/tsconfig.node.json new file mode 100644 index 0000000..951e593 --- /dev/null +++ b/frontend/tsconfig.node.json @@ -0,0 +1,12 @@ +{ + "compilerOptions": { + "composite": true, + "skipLibCheck": true, + "module": "ESNext", + "moduleResolution": "bundler", + "allowSyntheticDefaultImports": true + }, + "include": [ + "vite.config.js" + ] +} \ No newline at end of file diff --git a/frontend/vite.config.js b/frontend/vite.config.js index a7e9c40..51ed537 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -1,8 +1,14 @@ import { defineConfig } from 'vite' import react from '@vitejs/plugin-react' +import path from 'path' // https://vite.dev/config/ export default defineConfig({ plugins: [react()], base: '/ui/', + resolve: { + alias: { + '@': path.resolve(__dirname, './src'), + }, + }, }) diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py index 5ee8067..8711b5d 100644 --- a/src/memory/api/MCP/__init__.py +++ b/src/memory/api/MCP/__init__.py @@ -1,2 +1,3 @@ import memory.api.MCP.manifest import memory.api.MCP.memory +import memory.api.MCP.metadata diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index 4710afb..440d65f 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -127,7 +127,6 @@ async def handle_login(request: Request): return login_form(request, oauth_params, "Invalid email or password") redirect_url = await oauth_provider.complete_authorization(oauth_params, user) - print("redirect_url", redirect_url) if redirect_url.startswith("http://anysphere.cursor-retrieval"): redirect_url = redirect_url.replace("http://", "cursor://") return RedirectResponse(url=redirect_url, status_code=302) diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/memory.py index 89aaf30..6d63519 100644 --- a/src/memory/api/MCP/memory.py +++ b/src/memory/api/MCP/memory.py @@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system. 
""" import logging +import mimetypes import pathlib from datetime import datetime, timezone -import mimetypes +from typing import Any from pydantic import BaseModel -from sqlalchemy import Text, func +from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY +from memory.api.MCP.tools import mcp from memory.api.search.search import SearchFilters, search from memory.common import extract, settings +from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION +from memory.common.celery_app import app as celery_app from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS from memory.common.db.connection import make_session -from memory.common.db.models import AgentObservation, SourceItem +from memory.common.db.models import SourceItem, AgentObservation from memory.common.formatters import observation -from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE -from memory.api.MCP.tools import mcp logger = logging.getLogger(__name__) @@ -47,11 +49,13 @@ def filter_observation_source_ids( return source_ids -def filter_source_ids( - modalities: set[str], - tags: list[str] | None = None, -): - if not tags: +def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None: + if source_ids := filters.get("source_ids"): + return source_ids + + tags = filters.get("tags") + size = filters.get("size") + if not (tags or size): return None with make_session() as session: @@ -62,6 +66,8 @@ def filter_source_ids( items_query = items_query.filter( SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), ) + if size: + items_query = items_query.filter(SourceItem.size == size) if modalities: items_query = items_query.filter(SourceItem.modality.in_(modalities)) source_ids = [item.id for item in items_query.all()] @@ -69,51 +75,12 @@ def filter_source_ids( return source_ids -@mcp.tool() -async def get_all_tags() -> list[str]: - """ - Get all unique tags used across the entire knowledge base. - Returns sorted list of tags from both observations and content. - """ - with make_session() as session: - tags_query = session.query(func.unnest(SourceItem.tags)).distinct() - return sorted({row[0] for row in tags_query if row[0] is not None}) - - -@mcp.tool() -async def get_all_subjects() -> list[str]: - """ - Get all unique subjects from observations about the user. - Returns sorted list of subject identifiers used in observations. - """ - with make_session() as session: - return sorted( - r.subject for r in session.query(AgentObservation.subject).distinct() - ) - - -@mcp.tool() -async def get_all_observation_types() -> list[str]: - """ - Get all observation types that have been used. - Standard types are belief, preference, behavior, contradiction, general, but there can be more. 
- """ - with make_session() as session: - return sorted( - { - r.observation_type - for r in session.query(AgentObservation.observation_type).distinct() - if r.observation_type is not None - } - ) - - @mcp.tool() async def search_knowledge_base( query: str, - previews: bool = False, + filters: dict[str, Any], modalities: set[str] = set(), - tags: list[str] = [], + previews: bool = False, limit: int = 10, ) -> list[dict]: """ @@ -125,7 +92,7 @@ async def search_knowledge_base( query: Natural language search query - be descriptive about what you're looking for previews: Include actual content in results - when false only a snippet is returned modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all) - tags: Filter by tags - content must have at least one matching tag + filters: Filter by tags, source_ids, etc. limit: Max results (1-100) Returns: List of search results with id, score, chunks, content, filename @@ -137,6 +104,9 @@ async def search_knowledge_base( modalities = set(ALL_COLLECTIONS.keys()) modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS + search_filters = SearchFilters(**filters) + search_filters["source_ids"] = filter_source_ids(modalities, search_filters) + upload_data = extract.extract_text(query) results = await search( upload_data, @@ -145,10 +115,7 @@ async def search_knowledge_base( limit=limit, min_text_score=0.4, min_multimodal_score=0.25, - filters=SearchFilters( - tags=tags, - source_ids=filter_source_ids(tags=tags, modalities=modalities), - ), + filters=search_filters, ) return [result.model_dump() for result in results] diff --git a/src/memory/api/MCP/metadata.py b/src/memory/api/MCP/metadata.py new file mode 100644 index 0000000..ce54bec --- /dev/null +++ b/src/memory/api/MCP/metadata.py @@ -0,0 +1,119 @@ +import logging +from collections import defaultdict +from typing import Annotated, TypedDict, get_args, get_type_hints + +from memory.common import qdrant +from sqlalchemy import func + +from memory.api.MCP.tools import mcp +from memory.common.db.connection import make_session +from memory.common.db.models import SourceItem +from memory.common.db.models.source_items import AgentObservation + +logger = logging.getLogger(__name__) + + +class SchemaArg(TypedDict): + type: str | None + description: str | None + + +class CollectionMetadata(TypedDict): + schema: dict[str, SchemaArg] + size: int + + +def from_annotation(annotation: Annotated) -> SchemaArg | None: + try: + type_, description = get_args(annotation) + type_str = str(type_) + if type_str.startswith("typing."): + type_str = type_str[7:] + elif len((parts := type_str.split("'"))) > 1: + type_str = parts[1] + return SchemaArg(type=type_str, description=description) + except IndexError: + logger.error(f"Error from annotation: {annotation}") + return None + + +def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]: + if not hasattr(klass, "as_payload"): + return {} + + if not (payload_type := get_type_hints(klass.as_payload).get("return")): + return {} + + return { + name: schema + for name, arg in payload_type.__annotations__.items() + if (schema := from_annotation(arg)) + } + + +@mcp.tool() +async def get_metadata_schemas() -> dict[str, CollectionMetadata]: + """Get the metadata schema for each collection used in the knowledge base. + + These schemas can be used to filter the knowledge base. + + Returns: A mapping of collection names to their metadata schemas with field types and descriptions. 
+ + Example: + ``` + { + "mail": {"subject": {"type": "str", "description": "The subject of the email."}}, + "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}} + } + """ + client = qdrant.get_qdrant_client() + sizes = qdrant.get_collection_sizes(client) + schemas = defaultdict(dict) + for klass in SourceItem.__subclasses__(): + for collection in klass.get_collections(): + schemas[collection].update(get_schema(klass)) + + return { + collection: CollectionMetadata(schema=schema, size=size) + for collection, schema in schemas.items() + if (size := sizes.get(collection)) + } + + +@mcp.tool() +async def get_all_tags() -> list[str]: + """Get all unique tags used across the entire knowledge base. + + Returns sorted list of tags from both observations and content. + """ + with make_session() as session: + tags_query = session.query(func.unnest(SourceItem.tags)).distinct() + return sorted({row[0] for row in tags_query if row[0] is not None}) + + +@mcp.tool() +async def get_all_subjects() -> list[str]: + """Get all unique subjects from observations about the user. + + Returns sorted list of subject identifiers used in observations. + """ + with make_session() as session: + return sorted( + r.subject for r in session.query(AgentObservation.subject).distinct() + ) + + +@mcp.tool() +async def get_all_observation_types() -> list[str]: + """Get all observation types that have been used. + + Standard types are belief, preference, behavior, contradiction, general, but there can be more. + """ + with make_session() as session: + return sorted( + { + r.observation_type + for r in session.query(AgentObservation.observation_type).distinct() + if r.observation_type is not None + } + ) diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 57ffbd6..2f1974c 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]: # SQLAdmin setup with OAuth protection engine = get_engine() admin = Admin(app, engine) +admin.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # [settings.SERVER_URL], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # Setup admin with OAuth protection using existing OAuth provider setup_admin(admin) diff --git a/src/memory/api/search/embeddings.py b/src/memory/api/search/embeddings.py index 6af1d06..c1bfcfe 100644 --- a/src/memory/api/search/embeddings.py +++ b/src/memory/api/search/embeddings.py @@ -1,7 +1,7 @@ import base64 import io import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, cast import qdrant_client from PIL import Image @@ -90,13 +90,71 @@ def query_chunks( } +def merge_range_filter( + filters: list[dict[str, Any]], key: str, val: Any +) -> list[dict[str, Any]]: + direction, field = key.split("_", maxsplit=1) + item = next((f for f in filters if f["key"] == field), None) + if not item: + item = {"key": field, "range": {}} + filters.append(item) + + if direction == "min": + item["range"]["gte"] = val + elif direction == "max": + item["range"]["lte"] = val + return filters + + +def merge_filters( + filters: list[dict[str, Any]], key: str, val: Any +) -> list[dict[str, Any]]: + if not val and val != 0: + return filters + + list_filters = ["tags", "recipients", "observation_types", "authors"] + range_filters = [ + "min_sent_at", + "max_sent_at", + "min_published", + "max_published", + "min_size", + "max_size", + "min_created_at", + "max_created_at", + ] + if key 
in list_filters: + filters.append({"key": key, "match": {"any": val}}) + + elif key in range_filters: + return merge_range_filter(filters, key, val) + + elif key == "min_confidences": + confidence_filters = [ + { + "key": f"confidence.{confidence_type}", + "range": {"gte": min_confidence_score}, + } + for confidence_type, min_confidence_score in cast(dict, val).items() + ] + filters.extend(confidence_filters) + + elif key == "source_ids": + filters.append({"key": "id", "match": {"any": val}}) + + else: + filters.append({"key": key, "match": val}) + + return filters + + async def search_embeddings( data: list[extract.DataChunk], previews: Optional[bool] = False, modalities: set[str] = set(), limit: int = 10, min_score: float = 0.3, - filters: SearchFilters = SearchFilters(), + filters: SearchFilters = {}, multimodal: bool = False, ) -> list[tuple[SourceData, AnnotatedChunk]]: """ @@ -111,27 +169,11 @@ async def search_embeddings( - filters: Filters to apply to the search results - multimodal: Whether to search in multimodal collections """ - query_filters = {"must": []} - - # Handle structured confidence filtering - if min_confidences := filters.get("min_confidences"): - confidence_filters = [ - { - "key": f"confidence.{confidence_type}", - "range": {"gte": min_confidence_score}, - } - for confidence_type, min_confidence_score in min_confidences.items() - ] - query_filters["must"].extend(confidence_filters) - - if tags := filters.get("tags"): - query_filters["must"].append({"key": "tags", "match": {"any": tags}}) - - if observation_types := filters.get("observation_types"): - query_filters["must"].append( - {"key": "observation_type", "match": {"any": observation_types}} - ) + search_filters = [] + for key, val in filters.items(): + search_filters = merge_filters(search_filters, key, val) client = qdrant.get_qdrant_client() results = query_chunks( client, @@ -140,7 +182,7 @@ async def search_embeddings( embedding.embed_text if not multimodal else embedding.embed_mixed, min_score=min_score, limit=limit, - filters=query_filters if query_filters["must"] else None, + filters={"must": search_filters} if search_filters else None, ) search_results = {k: results.get(k, []) for k in modalities} diff --git a/src/memory/api/search/search.py b/src/memory/api/search/search.py index 2e07d5d..610d59d 100644 --- a/src/memory/api/search/search.py +++ b/src/memory/api/search/search.py @@ -17,6 +17,7 @@ from memory.common.collections import ( MULTIMODAL_COLLECTIONS, TEXT_COLLECTIONS, ) +from memory.common import settings logger = logging.getLogger(__name__) @@ -44,51 +45,57 @@ async def search( - List of search results sorted by score """ allowed_modalities = modalities & ALL_COLLECTIONS.keys() - text_embeddings_results = with_timeout( - search_embeddings( - data, - previews, - allowed_modalities & TEXT_COLLECTIONS, - limit, - min_text_score, - filters, - multimodal=False, - ), - timeout, - ) - multimodal_embeddings_results = with_timeout( - search_embeddings( - data, - previews, - allowed_modalities & MULTIMODAL_COLLECTIONS, - limit, - min_multimodal_score, - filters, - multimodal=True, - ), - timeout, - ) - bm25_results = with_timeout( - search_bm25( - " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]), - modalities, - limit=limit, - filters=filters, - ), - timeout, - ) + searches = [] + if settings.ENABLE_EMBEDDING_SEARCH: + searches = [ + with_timeout( + search_embeddings( + data, + previews, + allowed_modalities & 
TEXT_COLLECTIONS, + limit, + min_text_score, + filters, + multimodal=False, + ), + timeout, + ), + with_timeout( + search_embeddings( + data, + previews, + allowed_modalities & MULTIMODAL_COLLECTIONS, + limit, + min_multimodal_score, + filters, + multimodal=True, + ), + timeout, + ), + ] + if settings.ENABLE_BM25_SEARCH: + searches.append( + with_timeout( + search_bm25( + " ".join( + [c for chunk in data for c in chunk.data if isinstance(c, str)] + ), + modalities, + limit=limit, + filters=filters, + ), + timeout, + ) + ) - results = await asyncio.gather( - text_embeddings_results, - multimodal_embeddings_results, - bm25_results, - return_exceptions=False, - ) - text_results, multi_results, bm25_results = results - all_results = text_results + multi_results - if len(all_results) < limit: - all_results += bm25_results + search_results = await asyncio.gather(*searches, return_exceptions=False) + all_results = [] + for results in search_results: + if len(all_results) >= limit: + break + all_results.extend(results) results = group_chunks(all_results, previews or False) return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True) diff --git a/src/memory/api/search/utils.py b/src/memory/api/search/utils.py index 4a9da22..1833235 100644 --- a/src/memory/api/search/utils.py +++ b/src/memory/api/search/utils.py @@ -65,9 +65,9 @@ class SearchResult(BaseModel): class SearchFilters(TypedDict): - subject: NotRequired[str | None] + min_size: NotRequired[int] + max_size: NotRequired[int] min_confidences: NotRequired[dict[str, float]] - tags: NotRequired[list[str] | None] observation_types: NotRequired[list[str] | None] source_ids: NotRequired[list[int] | None] @@ -115,7 +115,6 @@ def group_chunks( if isinstance(contents, dict): tags = contents.pop("tags", []) content = contents.pop("content", None) - print(content) else: content = contents contents = {} diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 8888a8f..62ced5e 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -4,6 +4,7 @@ from memory.common.db.models.source_item import ( SourceItem, ConfidenceScore, clean_filename, + SourceItemPayload, ) from memory.common.db.models.source_items import ( MailMessage, @@ -19,6 +20,14 @@ from memory.common.db.models.source_items import ( Photo, MiscDoc, Note, + MailMessagePayload, + EmailAttachmentPayload, + AgentObservationPayload, + BlogPostPayload, + ComicPayload, + BookSectionPayload, + NotePayload, + ForumPostPayload, ) from memory.common.db.models.observations import ( ObservationContradiction, @@ -40,6 +49,18 @@ from memory.common.db.models.users import ( OAuthRefreshToken, ) +Payload = ( + SourceItemPayload + | AgentObservationPayload + | NotePayload + | BlogPostPayload + | ComicPayload + | BookSectionPayload + | ForumPostPayload + | EmailAttachmentPayload + | MailMessagePayload +) + __all__ = [ "Base", "Chunk", @@ -75,4 +96,6 @@ __all__ = [ "OAuthClientInformation", "OAuthState", "OAuthRefreshToken", + # Payloads + "Payload", ] diff --git a/src/memory/common/db/models/source_item.py b/src/memory/common/db/models/source_item.py index 8ec4a43..d6e68a5 100644 --- a/src/memory/common/db/models/source_item.py +++ b/src/memory/common/db/models/source_item.py @@ -4,7 +4,7 @@ Database models for the knowledge base system. 
import pathlib import re -from typing import Any, Sequence, cast +from typing import Any, Annotated, Sequence, TypedDict, cast import uuid from PIL import Image @@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer from memory.common.db.models.base import Base +class MetadataSchema(TypedDict): + type: str + description: str + + +class SourceItemPayload(TypedDict): + source_id: Annotated[int, "Unique identifier of the source item"] + tags: Annotated[list[str], "List of tags for categorization"] + size: Annotated[int | None, "Size of the content in bytes"] + + @event.listens_for(Session, "before_flush") def handle_duplicate_sha256(session, flush_context, instances): """ @@ -344,12 +355,17 @@ class SourceItem(Base): def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: return [self._make_chunk(data, metadata) for data in self._chunk_contents()] - def as_payload(self) -> dict: - return { - "source_id": self.id, - "tags": self.tags, - "size": self.size, - } + def as_payload(self) -> SourceItemPayload: + return SourceItemPayload( + source_id=cast(int, self.id), + tags=cast(list[str], self.tags), + size=cast(int | None, self.size), + ) + + @classmethod + def get_collections(cls) -> list[str]: + """Return the list of Qdrant collections this SourceItem type can be stored in.""" + return [cls.__tablename__] @property def display_contents(self) -> str | dict | None: diff --git a/src/memory/common/db/models/source_items.py b/src/memory/common/db/models/source_items.py index b391b6f..1c9f313 100644 --- a/src/memory/common/db/models/source_items.py +++ b/src/memory/common/db/models/source_items.py @@ -5,7 +5,7 @@ Database models for the knowledge base system. import pathlib import textwrap from datetime import datetime -from typing import Any, Sequence, cast +from typing import Any, Annotated, Sequence, cast from PIL import Image from sqlalchemy import ( @@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation from memory.common.db.models.source_item import ( SourceItem, Chunk, + SourceItemPayload, clean_filename, chunk_mixed, ) +class MailMessagePayload(SourceItemPayload): + message_id: Annotated[str, "Unique email message identifier"] + subject: Annotated[str, "Email subject line"] + sender: Annotated[str, "Email sender address"] + recipients: Annotated[list[str], "List of recipient email addresses"] + folder: Annotated[str, "Email folder name"] + date: Annotated[str | None, "Email sent date in ISO format"] + + class MailMessage(SourceItem): __tablename__ = "mail_message" @@ -80,17 +90,21 @@ class MailMessage(SourceItem): path.parent.mkdir(parents=True, exist_ok=True) return path - def as_payload(self) -> dict: - return { - **super().as_payload(), - "message_id": self.message_id, - "subject": self.subject, - "sender": self.sender, - "recipients": self.recipients, - "folder": self.folder, - "tags": self.tags + [self.sender] + self.recipients, - "date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore + def as_payload(self) -> MailMessagePayload: + base_payload = super().as_payload() | { + "tags": cast(list[str], self.tags) + + [cast(str, self.sender)] + + cast(list[str], self.recipients) } + return MailMessagePayload( + **cast(dict, base_payload), + message_id=cast(str, self.message_id), + subject=cast(str, self.subject), + sender=cast(str, self.sender), + recipients=cast(list[str], self.recipients), + folder=cast(str, self.folder), + date=(self.sent_at and self.sent_at.isoformat() or None), # type: ignore + ) @property def 
parsed_content(self) -> dict[str, Any]: @@ -152,7 +166,7 @@ class MailMessage(SourceItem): def _chunk_contents(self) -> Sequence[extract.DataChunk]: content = self.parsed_content - chunks = extract.extract_text(cast(str, self.body)) + chunks = extract.extract_text(cast(str, self.body), modality="email") def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk: if isinstance(item, str): @@ -163,6 +177,10 @@ class MailMessage(SourceItem): chunk.data = [add_header(item) for item in chunk.data] return chunks + @classmethod + def get_collections(cls) -> list[str]: + return ["mail"] + # Add indexes __table_args__ = ( Index("mail_sent_idx", "sent_at"), @@ -171,6 +189,13 @@ class MailMessage(SourceItem): ) +class EmailAttachmentPayload(SourceItemPayload): + filename: Annotated[str, "Name of the document file"] + content_type: Annotated[str, "MIME type of the document"] + mail_message_id: Annotated[int, "Associated email message ID"] + sent_at: Annotated[str | None, "Document creation timestamp"] + + class EmailAttachment(SourceItem): __tablename__ = "email_attachment" @@ -190,17 +215,20 @@ class EmailAttachment(SourceItem): "polymorphic_identity": "email_attachment", } - def as_payload(self) -> dict: - return { + def as_payload(self) -> EmailAttachmentPayload: + return EmailAttachmentPayload( **super().as_payload(), - "filename": self.filename, - "content_type": self.mime_type, - "size": self.size, - "created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore - "mail_message_id": self.mail_message_id, - } + filename=cast(str, self.filename), + content_type=cast(str, self.mime_type), + mail_message_id=cast(int, self.mail_message_id), + sent_at=( + self.mail_message.sent_at + and self.mail_message.sent_at.isoformat() + or None + ), # type: ignore + ) - def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]: + def _chunk_contents(self) -> Sequence[extract.DataChunk]: if cast(str | None, self.filename): contents = ( settings.FILE_STORAGE_DIR / cast(str, self.filename) @@ -208,8 +236,7 @@ class EmailAttachment(SourceItem): else: contents = cast(str, self.content) - chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents) - return [self._make_chunk(c, metadata) for c in chunks] + return extract.extract_data_chunks(cast(str, self.mime_type), contents) @property def display_contents(self) -> dict: @@ -221,6 +248,11 @@ class EmailAttachment(SourceItem): # Add indexes __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),) + @classmethod + def get_collections(cls) -> list[str]: + """EmailAttachment can go to different collections based on mime_type""" + return ["doc", "text", "blog", "photo", "book"] + class ChatMessage(SourceItem): __tablename__ = "chat_message" @@ -285,6 +317,16 @@ class Photo(SourceItem): __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),) +class ComicPayload(SourceItemPayload): + title: Annotated[str, "Title of the comic"] + author: Annotated[str | None, "Author of the comic"] + published: Annotated[str | None, "Publication date in ISO format"] + volume: Annotated[str | None, "Volume number"] + issue: Annotated[str | None, "Issue number"] + page: Annotated[int | None, "Page number"] + url: Annotated[str | None, "URL of the comic"] + + class Comic(SourceItem): __tablename__ = "comic" @@ -305,18 +347,17 @@ class Comic(SourceItem): __table_args__ = (Index("comic_author_idx", "author"),) - def as_payload(self) -> dict: - payload = { + def as_payload(self) -> ComicPayload: + return 
+        return ComicPayload(
             **super().as_payload(),
-            "title": self.title,
-            "author": self.author,
-            "published": self.published,
-            "volume": self.volume,
-            "issue": self.issue,
-            "page": self.page,
-            "url": self.url,
-        }
-        return {k: v for k, v in payload.items() if v is not None}
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(self.published and self.published.isoformat() or None),  # type: ignore
+            volume=cast(str | None, self.volume),
+            issue=cast(str | None, self.issue),
+            page=cast(int | None, self.page),
+            url=cast(str | None, self.url),
+        )
 
     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
@@ -324,6 +365,17 @@ class Comic(SourceItem):
         return [extract.DataChunk(data=[image, description])]
 
 
+class BookSectionPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the book"]
+    author: Annotated[str | None, "Author of the book"]
+    book_id: Annotated[int, "Unique identifier of the book"]
+    section_title: Annotated[str, "Title of the section"]
+    section_number: Annotated[int, "Number of the section"]
+    section_level: Annotated[int, "Level of the section"]
+    start_page: Annotated[int, "Starting page number"]
+    end_page: Annotated[int, "Ending page number"]
+
+
 class BookSection(SourceItem):
     """Individual sections/chapters of books"""
 
@@ -361,19 +413,22 @@ class BookSection(SourceItem):
         Index("book_section_level_idx", "section_level", "section_number"),
     )
 
-    def as_payload(self) -> dict:
-        vals = {
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["book"]
+
+    def as_payload(self) -> BookSectionPayload:
+        return BookSectionPayload(
             **super().as_payload(),
-            "title": self.book.title,
-            "author": self.book.author,
-            "book_id": self.book_id,
-            "section_title": self.section_title,
-            "section_number": self.section_number,
-            "section_level": self.section_level,
-            "start_page": self.start_page,
-            "end_page": self.end_page,
-        }
-        return {k: v for k, v in vals.items() if v}
+            title=cast(str, self.book.title),
+            author=cast(str | None, self.book.author),
+            book_id=cast(int, self.book_id),
+            section_title=cast(str, self.section_title),
+            section_number=cast(int, self.section_number),
+            section_level=cast(int, self.section_level),
+            start_page=cast(int, self.start_page),
+            end_page=cast(int, self.end_page),
+        )
 
     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = cast(str, self.content.strip())
@@ -397,6 +452,16 @@ class BookSection(SourceItem):
         ]
 
 
+class BlogPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the blog post"]
+    title: Annotated[str, "Title of the blog post"]
+    author: Annotated[str | None, "Author of the blog post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    description: Annotated[str | None, "Description of the blog post"]
+    domain: Annotated[str | None, "Domain of the blog post"]
+    word_count: Annotated[int | None, "Word count of the blog post"]
+
+
 class BlogPost(SourceItem):
     __tablename__ = "blog_post"
 
@@ -428,27 +493,39 @@ class BlogPost(SourceItem):
         Index("blog_post_word_count_idx", "word_count"),
     )
 
-    def as_payload(self) -> dict:
+    def as_payload(self) -> BlogPostPayload:
         published_date = cast(datetime | None, self.published)
         metadata = cast(dict | None, self.webpage_metadata) or {}
 
-        payload = {
+        return BlogPostPayload(
             **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "author": self.author,
-            "published": published_date and published_date.isoformat(),
-            "description": self.description,
- "domain": self.domain, - "word_count": self.word_count, + url=cast(str, self.url), + title=cast(str, self.title), + author=cast(str | None, self.author), + published=(published_date and published_date.isoformat() or None), # type: ignore + description=cast(str | None, self.description), + domain=cast(str | None, self.domain), + word_count=cast(int | None, self.word_count), **metadata, - } - return {k: v for k, v in payload.items() if v} + ) def _chunk_contents(self) -> Sequence[extract.DataChunk]: return chunk_mixed(cast(str, self.content), cast(list[str], self.images)) +class ForumPostPayload(SourceItemPayload): + url: Annotated[str, "URL of the forum post"] + title: Annotated[str, "Title of the forum post"] + description: Annotated[str | None, "Description of the forum post"] + authors: Annotated[list[str] | None, "Authors of the forum post"] + published: Annotated[str | None, "Publication date in ISO format"] + slug: Annotated[str | None, "Slug of the forum post"] + karma: Annotated[int | None, "Karma score of the forum post"] + votes: Annotated[int | None, "Number of votes on the forum post"] + score: Annotated[int | None, "Score of the forum post"] + comments: Annotated[int | None, "Number of comments on the forum post"] + + class ForumPost(SourceItem): __tablename__ = "forum_post" @@ -479,20 +556,20 @@ class ForumPost(SourceItem): Index("forum_post_title_idx", "title"), ) - def as_payload(self) -> dict: - return { + def as_payload(self) -> ForumPostPayload: + return ForumPostPayload( **super().as_payload(), - "url": self.url, - "title": self.title, - "description": self.description, - "authors": self.authors, - "published_at": self.published_at, - "slug": self.slug, - "karma": self.karma, - "votes": self.votes, - "score": self.score, - "comments": self.comments, - } + url=cast(str, self.url), + title=cast(str, self.title), + description=cast(str | None, self.description), + authors=cast(list[str] | None, self.authors), + published=(self.published_at and self.published_at.isoformat() or None), # type: ignore + slug=cast(str | None, self.slug), + karma=cast(int | None, self.karma), + votes=cast(int | None, self.votes), + score=cast(int | None, self.score), + comments=cast(int | None, self.comments), + ) def _chunk_contents(self) -> Sequence[extract.DataChunk]: return chunk_mixed(cast(str, self.content), cast(list[str], self.images)) @@ -545,6 +622,12 @@ class GithubItem(SourceItem): ) +class NotePayload(SourceItemPayload): + note_type: Annotated[str | None, "Category of the note"] + subject: Annotated[str | None, "What the note is about"] + confidence: Annotated[dict[str, float], "Confidence scores for the note"] + + class Note(SourceItem): """A quick note of something of interest.""" @@ -565,13 +648,13 @@ class Note(SourceItem): Index("note_subject_idx", "subject"), ) - def as_payload(self) -> dict: - return { + def as_payload(self) -> NotePayload: + return NotePayload( **super().as_payload(), - "note_type": self.note_type, - "subject": self.subject, - "confidence": self.confidence_dict, - } + note_type=cast(str | None, self.note_type), + subject=cast(str | None, self.subject), + confidence=self.confidence_dict, + ) @property def display_contents(self) -> dict: @@ -602,6 +685,19 @@ class Note(SourceItem): self.as_text(cast(str, self.content), cast(str | None, self.subject)) ) + @classmethod + def get_collections(cls) -> list[str]: + return ["text"] # Notes go to the text collection + + +class AgentObservationPayload(SourceItemPayload): + session_id: Annotated[str | None, "Session ID 
for the observation"] + observation_type: Annotated[str, "Type of observation"] + subject: Annotated[str, "What/who the observation is about"] + confidence: Annotated[dict[str, float], "Confidence scores for the observation"] + evidence: Annotated[dict | None, "Supporting context, quotes, etc."] + agent_model: Annotated[str, "Which AI model made this observation"] + class AgentObservation(SourceItem): """ @@ -652,18 +748,16 @@ class AgentObservation(SourceItem): kwargs["modality"] = "observation" super().__init__(**kwargs) - def as_payload(self) -> dict: - payload = { + def as_payload(self) -> AgentObservationPayload: + return AgentObservationPayload( **super().as_payload(), - "observation_type": self.observation_type, - "subject": self.subject, - "confidence": self.confidence_dict, - "evidence": self.evidence, - "agent_model": self.agent_model, - } - if self.session_id is not None: - payload["session_id"] = str(self.session_id) - return payload + observation_type=cast(str, self.observation_type), + subject=cast(str, self.subject), + confidence=self.confidence_dict, + evidence=cast(dict | None, self.evidence), + agent_model=cast(str, self.agent_model), + session_id=cast(str | None, self.session_id) and str(self.session_id), + ) @property def all_contradictions(self): @@ -759,3 +853,7 @@ class AgentObservation(SourceItem): # )) return chunks + + @classmethod + def get_collections(cls) -> list[str]: + return ["semantic", "temporal"] diff --git a/src/memory/common/extract.py b/src/memory/common/extract.py index 3df07df..2fbe50a 100644 --- a/src/memory/common/extract.py +++ b/src/memory/common/extract.py @@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image: return image -def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]: +def doc_to_images( + content: bytes | str | pathlib.Path, modality: str = "doc" +) -> list[DataChunk]: with as_file(content) as file_path: with pymupdf.open(file_path) as pdf: return [ @@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]: "height": page.rect.height, }, mime_type="image/jpeg", + modality=modality, ) for page in pdf.pages() ] @@ -122,6 +125,7 @@ def extract_text( content: bytes | str | pathlib.Path, chunk_size: int | None = None, metadata: dict[str, Any] = {}, + modality: str = "text", ) -> list[DataChunk]: if isinstance(content, pathlib.Path): content = content.read_text() @@ -130,7 +134,7 @@ def extract_text( content = cast(str, content) chunks = [ - DataChunk(data=[c], modality="text", metadata=metadata) + DataChunk(data=[c], modality=modality, metadata=metadata) for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS) ] if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2: @@ -139,7 +143,7 @@ def extract_text( DataChunk( data=[summary], metadata=merge_metadata(metadata, {"tags": tags}), - modality="text", + modality=modality, ) ) return chunks @@ -158,9 +162,7 @@ def extract_data_chunks( "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/msword", ]: - logger.info(f"Extracting content from {content}") chunks = extract_docx(content) - logger.info(f"Extracted {len(chunks)} pages from {content}") elif mime_type.startswith("text/"): chunks = extract_text(content, chunk_size) elif mime_type.startswith("image/"): diff --git a/src/memory/common/qdrant.py b/src/memory/common/qdrant.py index 5e0121e..5e6ad8f 100644 --- a/src/memory/common/qdrant.py +++ b/src/memory/common/qdrant.py @@ -224,6 +224,15 @@ def 
@@ -224,6 +224,15 @@ def get_collection_info(
     return info.model_dump()
 
 
+def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
+    """Get the size of each collection."""
+    collections = [i.name for i in client.get_collections().collections]
+    return {
+        collection_name: client.count(collection_name).count  # type: ignore
+        for collection_name in collections
+    }
+
+
 def batch_ids(
     client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
 ) -> Generator[list[str], None, None]:
diff --git a/src/memory/common/settings.py b/src/memory/common/settings.py
index 9de24f1..652faf0 100644
--- a/src/memory/common/settings.py
+++ b/src/memory/common/settings.py
@@ -6,6 +6,8 @@ load_dotenv()
 
 
 def boolean_env(key: str, default: bool = False) -> bool:
+    if key not in os.environ:
+        return default
     return os.getenv(key, "0").lower() in ("1", "true", "yes")
 
 
@@ -130,6 +132,10 @@ else:
 ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
 
+# Search settings
+ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
+ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+
 # API settings
 SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
 HTTPS = boolean_env("HTTPS", False)
diff --git a/tests/memory/api/search/test_search_embeddings.py b/tests/memory/api/search/test_search_embeddings.py
new file mode 100644
index 0000000..ca9a6a8
--- /dev/null
+++ b/tests/memory/api/search/test_search_embeddings.py
@@ -0,0 +1,174 @@
+import pytest
+from memory.api.search.embeddings import merge_range_filter, merge_filters
+
+
+def test_merge_range_filter_new_filter():
+    """Test adding new range filters"""
+    filters = []
+    result = merge_range_filter(filters, "min_size", 100)
+    assert result == [{"key": "size", "range": {"gte": 100}}]
+
+    filters = []
+    result = merge_range_filter(filters, "max_size", 1000)
+    assert result == [{"key": "size", "range": {"lte": 1000}}]
+
+
+def test_merge_range_filter_existing_field():
+    """Test adding to existing field"""
+    filters = [{"key": "size", "range": {"lte": 1000}}]
+    result = merge_range_filter(filters, "min_size", 100)
+    assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]
+
+
+def test_merge_range_filter_override_existing():
+    """Test overriding existing values"""
+    filters = [{"key": "size", "range": {"gte": 100}}]
+    result = merge_range_filter(filters, "min_size", 200)
+    assert result == [{"key": "size", "range": {"gte": 200}}]
+
+
+def test_merge_range_filter_with_other_filters():
+    """Test adding range filter alongside other filters"""
+    filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
+    result = merge_range_filter(filters, "min_size", 100)
+
+    expected = [
+        {"key": "tags", "match": {"any": ["tag1"]}},
+        {"key": "size", "range": {"gte": 100}},
+    ]
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "key,expected_direction,expected_field",
+    [
+        ("min_sent_at", "min", "sent_at"),
+        ("max_sent_at", "max", "sent_at"),
+        ("min_published", "min", "published"),
+        ("max_published", "max", "published"),
+        ("min_size", "min", "size"),
+        ("max_size", "max", "size"),
+    ],
+)
+def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
+    """Test that field names are correctly extracted from keys"""
+    filters = []
+    result = merge_range_filter(filters, key, 100)
+
+    assert len(result) == 1
+    assert result[0]["key"] == expected_field
+    range_key = "gte" if expected_direction == "min" else "lte"
result[0]["range"][range_key] == 100 + + +@pytest.mark.parametrize( + "filter_key,filter_value", + [ + ("tags", ["tag1", "tag2"]), + ("recipients", ["user1", "user2"]), + ("observation_types", ["belief", "preference"]), + ("authors", ["author1"]), + ], +) +def test_merge_filters_list_filters(filter_key, filter_value): + """Test list filters that use match any""" + filters = [] + result = merge_filters(filters, filter_key, filter_value) + expected = [{"key": filter_key, "match": {"any": filter_value}}] + assert result == expected + + +def test_merge_filters_min_confidences(): + """Test min_confidences filter creates multiple range conditions""" + filters = [] + confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9} + result = merge_filters(filters, "min_confidences", confidences) + + expected = [ + {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}}, + {"key": "confidence.source_reliability", "range": {"gte": 0.9}}, + ] + assert result == expected + + +def test_merge_filters_source_ids(): + """Test source_ids filter maps to id field""" + filters = [] + result = merge_filters(filters, "source_ids", ["id1", "id2"]) + expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}] + assert result == expected + + +def test_merge_filters_range_delegation(): + """Test range filters are properly delegated to merge_range_filter""" + filters = [] + result = merge_filters(filters, "min_size", 100) + + assert len(result) == 1 + assert "range" in result[0] + assert result[0]["range"]["gte"] == 100 + + +def test_merge_filters_combined_range(): + """Test that min/max range pairs merge into single filter""" + filters = [] + filters = merge_filters(filters, "min_size", 100) + filters = merge_filters(filters, "max_size", 1000) + + size_filters = [f for f in filters if f["key"] == "size"] + assert len(size_filters) == 1 + assert size_filters[0]["range"]["gte"] == 100 + assert size_filters[0]["range"]["lte"] == 1000 + + +def test_merge_filters_preserves_existing(): + """Test that existing filters are preserved when adding new ones""" + existing_filters = [{"key": "existing", "match": "value"}] + result = merge_filters(existing_filters, "tags", ["new_tag"]) + + assert len(result) == 2 + assert {"key": "existing", "match": "value"} in result + assert {"key": "tags", "match": {"any": ["new_tag"]}} in result + + +def test_merge_filters_realistic_combination(): + """Test a realistic filter combination for knowledge base search""" + filters = [] + + # Add typical knowledge base filters + filters = merge_filters(filters, "tags", ["important", "work"]) + filters = merge_filters(filters, "min_published", "2023-01-01") + filters = merge_filters(filters, "max_size", 1000000) # 1MB max + filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8}) + + assert len(filters) == 4 + + # Check each filter type + tag_filter = next(f for f in filters if f["key"] == "tags") + assert tag_filter["match"]["any"] == ["important", "work"] + + published_filter = next(f for f in filters if f["key"] == "published") + assert published_filter["range"]["gte"] == "2023-01-01" + + size_filter = next(f for f in filters if f["key"] == "size") + assert size_filter["range"]["lte"] == 1000000 + + confidence_filter = next( + f for f in filters if f["key"] == "confidence.observation_accuracy" + ) + assert confidence_filter["range"]["gte"] == 0.8 + + +def test_merge_filters_unknown_key(): + """Test fallback behavior for unknown filter keys""" + filters = [] + result = merge_filters(filters, 
"unknown_field", "unknown_value") + expected = [{"key": "unknown_field", "match": "unknown_value"}] + assert result == expected + + +def test_merge_filters_empty_min_confidences(): + """Test min_confidences with empty dict does nothing""" + filters = [] + result = merge_filters(filters, "min_confidences", {}) + assert result == []