search filters

Daniel O'Connell 2025-06-10 12:16:54 +02:00
parent 780e27ba04
commit 3e4e5872d1
37 changed files with 1632 additions and 354 deletions

View File

@@ -285,6 +285,7 @@ body {
display: flex;
gap: 1rem;
align-items: center;
+ margin-bottom: 1.5rem;
}

.search-input {
@@ -323,6 +324,34 @@ body {
cursor: not-allowed;
}
.search-options {
border-top: 1px solid #e2e8f0;
padding-top: 1.5rem;
display: grid;
gap: 1.5rem;
}
.search-option {
display: flex;
flex-direction: column;
gap: 0.5rem;
}
.search-option label {
font-weight: 500;
color: #4a5568;
font-size: 0.9rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.search-option input[type="checkbox"] {
margin-right: 0.5rem;
transform: scale(1.1);
accent-color: #667eea;
}
/* Search Results */
.search-results {
margin-top: 2rem;
@@ -429,6 +458,11 @@ body {
transition: background-color 0.2s;
}
.tag.selected {
background: #667eea;
color: white;
}
.tag:hover {
background: #e2e8f0;
}
@@ -495,6 +529,14 @@ body {
padding: 1.5rem;
}
.search-options {
gap: 1rem;
}
.limit-input {
width: 80px;
}
.search-result-card {
padding: 1rem;
}
@@ -689,4 +731,194 @@ body {
.markdown-preview a:hover {
color: #2b6cb0;
}
/* Dynamic filters styles */
.modality-filters {
display: flex;
flex-direction: column;
gap: 1rem;
margin-top: 0.5rem;
}
.modality-filter-group {
border: 1px solid #e2e8f0;
border-radius: 8px;
background: white;
}
.modality-filter-title {
padding: 0.75rem 1rem;
font-weight: 600;
color: #4a5568;
background: #f8fafc;
border-radius: 7px 7px 0 0;
cursor: pointer;
user-select: none;
transition: background-color 0.2s;
}
.modality-filter-title:hover {
background: #edf2f7;
}
.modality-filter-group[open] .modality-filter-title {
border-bottom: 1px solid #e2e8f0;
border-radius: 7px 7px 0 0;
}
.filters-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 1rem;
padding: 1rem;
}
.filter-field {
display: flex;
flex-direction: column;
gap: 0.25rem;
}
.filter-label {
font-size: 0.875rem;
font-weight: 500;
color: #4a5568;
margin-bottom: 0.25rem;
}
.filter-input {
padding: 0.5rem;
border: 1px solid #e2e8f0;
border-radius: 4px;
font-size: 0.875rem;
background: white;
transition: border-color 0.2s;
}
.filter-input:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 1px #667eea;
}
/* Selectable tags controls */
.selectable-tags-details {
border: 1px solid #e2e8f0;
border-radius: 8px;
background: white;
}
.selectable-tags-summary {
padding: 0.75rem 1rem;
font-weight: 600;
color: #4a5568;
cursor: pointer;
user-select: none;
transition: background-color 0.2s;
border-radius: 7px;
}
.selectable-tags-summary:hover {
background: #f8fafc;
}
.selectable-tags-details[open] .selectable-tags-summary {
border-bottom: 1px solid #e2e8f0;
border-radius: 7px 7px 0 0;
background: #f8fafc;
}
.tag-controls {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.75rem 1rem;
border-bottom: 1px solid #e2e8f0;
background: #fafbfc;
}
.tag-control-btn {
background: #f7fafc;
border: 1px solid #e2e8f0;
border-radius: 4px;
padding: 0.25rem 0.5rem;
font-size: 0.75rem;
color: #4a5568;
cursor: pointer;
transition: all 0.2s;
}
.tag-control-btn:hover:not(:disabled) {
background: #edf2f7;
border-color: #cbd5e0;
}
.tag-control-btn:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.tag-count {
font-size: 0.75rem;
color: #718096;
font-weight: 500;
}
/* Tag search controls */
.tag-search-controls {
display: flex;
align-items: center;
gap: 1rem;
padding: 0.75rem 1rem;
background: #f8fafc;
border-bottom: 1px solid #e2e8f0;
}
.tag-search-input {
flex: 1;
padding: 0.375rem 0.5rem;
border: 1px solid #e2e8f0;
border-radius: 4px;
font-size: 0.875rem;
background: white;
transition: border-color 0.2s;
}
.tag-search-input:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 1px #667eea;
}
.filtered-count {
font-size: 0.75rem;
color: #718096;
font-weight: 500;
white-space: nowrap;
}
.tags-display-area {
padding: 1rem;
}
@media (max-width: 768px) {
.filters-grid {
grid-template-columns: 1fr;
gap: 0.5rem;
padding: 0.75rem;
}
.modality-filter-title {
padding: 0.5rem 0.75rem;
font-size: 0.9rem;
}
.tag-search-controls {
flex-direction: column;
align-items: stretch;
gap: 0.5rem;
}
}

View File

@@ -2,9 +2,9 @@ import { useEffect } from 'react'
import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom'
import './App.css'
-import { useAuth } from './hooks/useAuth'
-import { useOAuth } from './hooks/useOAuth'
-import { Loading, LoginPrompt, AuthError, Dashboard, Search } from './components'
+import { useAuth } from '@/hooks/useAuth'
+import { useOAuth } from '@/hooks/useOAuth'
+import { Loading, LoginPrompt, AuthError, Dashboard, Search } from '@/components'

// AuthWrapper handles redirects based on auth state
const AuthWrapper = () => {
@@ -31,7 +31,10 @@ const AuthWrapper = () => {
// Handle redirects based on auth state changes
useEffect(() => {
if (!isLoading) {
-if (isAuthenticated) {
+if (location.pathname === '/ui/logout') {
+logout()
+navigate('/ui/login', { replace: true })
+} else if (isAuthenticated) {
// If authenticated and on login page, redirect to dashboard
if (location.pathname === '/ui/login' || location.pathname === '/ui') {
navigate('/ui/dashboard', { replace: true })

View File

@@ -1,5 +1,5 @@
import { Link } from 'react-router-dom'
-import { useMCP } from '../hooks/useMCP'
+import { useMCP } from '@/hooks/useMCP'

const Dashboard = ({ onLogout }) => {
const { listNotes } = useMCP()

View File

@@ -1,6 +1,4 @@
-import React from 'react'
-
-const Loading = ({ message = "Loading..." }) => {
+const Loading = ({ message = "Loading..." }: { message?: string }) => {
return (
<div className="loading">
<h2>{message}</h2>

View File

@@ -1,6 +1,4 @@
-import React from 'react'
-
-const AuthError = ({ error, onRetry }) => {
+const AuthError = ({ error, onRetry }: { error: string, onRetry: () => void }) => {
return (
<div className="error">
<h2>Authentication Error</h2>

View File

@@ -1,6 +1,4 @@
-import React from 'react'
-
-const LoginPrompt = ({ onLogin }) => {
+const LoginPrompt = ({ onLogin }: { onLogin: () => void }) => {
return (
<div className="login-prompt">
<h1>Memory App</h1>

View File

@@ -1,5 +1,5 @@
export { default as Loading } from './Loading'
export { default as Dashboard } from './Dashboard'
-export { default as Search } from './Search'
+export { default as Search } from './search'
export { default as LoginPrompt } from './auth/LoginPrompt'
export { default as AuthError } from './auth/AuthError'

View File

@@ -0,0 +1,155 @@
import { FilterInput } from './FilterInput'
import { CollectionMetadata, SchemaArg } from '@/types/mcp'
// Pure helper functions for schema processing
const formatFieldLabel = (field: string): string =>
field.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase())
const shouldSkipField = (fieldName: string): boolean =>
['tags'].includes(fieldName)
const isCommonField = (fieldName: string): boolean =>
['size', 'filename', 'content_type', 'mail_message_id', 'sent_at', 'created_at', 'source_id'].includes(fieldName)
const createMinMaxFields = (fieldName: string, fieldConfig: any): [string, any][] => [
[`min_${fieldName}`, { ...fieldConfig, description: `Min ${fieldConfig.description}` }],
[`max_${fieldName}`, { ...fieldConfig, description: `Max ${fieldConfig.description}` }]
]
const createSizeFields = (): [string, any][] => [
['min_size', { type: 'int', description: 'Minimum size in bytes' }],
['max_size', { type: 'int', description: 'Maximum size in bytes' }]
]
const extractSchemaFields = (schema: Record<string, SchemaArg>, includeCommon = true): [string, SchemaArg][] =>
Object.entries(schema)
.filter(([fieldName]) => !shouldSkipField(fieldName))
.filter(([fieldName]) => includeCommon || !isCommonField(fieldName))
.flatMap(([fieldName, fieldConfig]) =>
['sent_at', 'published', 'created_at'].includes(fieldName)
? createMinMaxFields(fieldName, fieldConfig)
: [[fieldName, fieldConfig] as [string, SchemaArg]]
)
const getCommonFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): [string, SchemaArg][] => {
const commonFieldsMap = new Map<string, SchemaArg>()
// Manually add created_at fields even if not in schema
createMinMaxFields('created_at', { type: 'datetime', description: 'Creation date' }).forEach(([field, config]) => {
commonFieldsMap.set(field, config)
})
selectedModalities.forEach(modality => {
const schema = schemas[modality].schema
if (!schema) return
Object.entries(schema).forEach(([fieldName, fieldConfig]) => {
if (isCommonField(fieldName)) {
if (['sent_at', 'created_at'].includes(fieldName)) {
createMinMaxFields(fieldName, fieldConfig).forEach(([field, config]) => {
commonFieldsMap.set(field, config)
})
} else if (fieldName === 'size') {
createSizeFields().forEach(([field, config]) => {
commonFieldsMap.set(field, config)
})
} else {
commonFieldsMap.set(fieldName, fieldConfig)
}
}
})
})
return Array.from(commonFieldsMap.entries())
}
const getModalityFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): Record<string, [string, SchemaArg][]> => {
return selectedModalities.reduce((acc, modality) => {
const schema = schemas[modality].schema
if (!schema) return acc
const schemaFields = extractSchemaFields(schema, false) // Exclude common fields
if (schemaFields.length > 0) {
acc[modality] = schemaFields
}
return acc
}, {} as Record<string, [string, SchemaArg][]>)
}
interface DynamicFiltersProps {
schemas: Record<string, CollectionMetadata>
selectedModalities: string[]
filters: Record<string, SchemaArg>
onFilterChange: (field: string, value: SchemaArg) => void
}
export const DynamicFilters = ({
schemas,
selectedModalities,
filters,
onFilterChange
}: DynamicFiltersProps) => {
const commonFields = getCommonFields(schemas, selectedModalities)
const modalityFields = getModalityFields(schemas, selectedModalities)
if (commonFields.length === 0 && Object.keys(modalityFields).length === 0) {
return null
}
return (
<div className="search-option">
<label>Filters:</label>
<div className="modality-filters">
{/* Common/Document Properties Section */}
{commonFields.length > 0 && (
<details className="modality-filter-group" open>
<summary className="modality-filter-title">
Document Properties
</summary>
<div className="filters-grid">
{commonFields.map(([field, fieldConfig]: [string, SchemaArg]) => (
<div key={field} className="filter-field">
<label className="filter-label">
{formatFieldLabel(field)}:
</label>
<FilterInput
field={field}
fieldConfig={fieldConfig}
value={filters[field]}
onChange={onFilterChange}
/>
</div>
))}
</div>
</details>
)}
{/* Modality-specific sections */}
{Object.entries(modalityFields).map(([modality, fields]) => (
<details key={modality} className="modality-filter-group">
<summary className="modality-filter-title">
{formatFieldLabel(modality)} Specific
</summary>
<div className="filters-grid">
{fields.map(([field, fieldConfig]: [string, SchemaArg]) => (
<div key={field} className="filter-field">
<label className="filter-label">
{formatFieldLabel(field)}:
</label>
<FilterInput
field={field}
fieldConfig={fieldConfig}
value={filters[field]}
onChange={onFilterChange}
/>
</div>
))}
</div>
</details>
))}
</div>
</div>
)
}
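
To illustrate how the helpers above partition a schema response, a hedged sketch (the collection and field names here are hypothetical, not taken from the real collections):

// Hypothetical Record<string, CollectionMetadata> input
const exampleSchemas = {
mail: { size: 120, schema: { subject: { type: 'str', description: 'Email subject line' }, sent_at: { type: 'datetime', description: 'sent date' }, size: { type: 'int', description: 'size in bytes' } } },
blog: { size: 40, schema: { author: { type: 'str', description: 'Post author' }, size: { type: 'int', description: 'size in bytes' } } },
}
// getCommonFields(exampleSchemas, ['mail', 'blog']) would yield range pairs for the shared fields:
//   min_created_at/max_created_at (always injected manually), min_sent_at/max_sent_at, min_size/max_size
// getModalityFields(exampleSchemas, ['mail', 'blog']) keeps only the non-common leftovers:
//   { mail: [['subject', ...]], blog: [['author', ...]] }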

View File

@@ -0,0 +1,53 @@
// Pure helper functions
const isDateField = (field: string): boolean =>
field.includes('sent_at') || field.includes('published') || field.includes('created_at')
const isNumberField = (field: string, fieldConfig: any): boolean =>
fieldConfig.type?.includes('int') || field.includes('size')
const parseNumberValue = (value: string): number | null =>
value ? parseInt(value) : null
interface FilterInputProps {
field: string
fieldConfig: any
value: any
onChange: (field: string, value: any) => void
}
export const FilterInput = ({ field, fieldConfig, value, onChange }: FilterInputProps) => {
const inputProps = {
value: value || '',
className: "filter-input"
}
if (isNumberField(field, fieldConfig)) {
return (
<input
{...inputProps}
type="number"
onChange={(e) => onChange(field, parseNumberValue(e.target.value))}
placeholder={fieldConfig.description}
/>
)
}
if (isDateField(field)) {
return (
<input
{...inputProps}
type="date"
onChange={(e) => onChange(field, e.target.value || null)}
/>
)
}
return (
<input
{...inputProps}
type="text"
onChange={(e) => onChange(field, e.target.value || null)}
placeholder={fieldConfig.description}
/>
)
}

View File

@@ -0,0 +1,69 @@
import { useState } from 'react'
import { useNavigate } from 'react-router-dom'
import { useMCP } from '@/hooks/useMCP'
import Loading from '@/components/Loading'
import { SearchResult } from './results'
import SearchForm, { SearchParams } from './SearchForm'
const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
if (isLoading) {
return <Loading message="Searching..." />
}
return (
<div className="search-results">
{results.length > 0 && (
<div className="results-count">
Found {results.length} result{results.length !== 1 ? 's' : ''}
</div>
)}
{results.map((result, index) => <SearchResult key={index} result={result} />)}
{results.length === 0 && (
<div className="no-results">
No results found
</div>
)}
</div>
)
}
const Search = () => {
const navigate = useNavigate()
const [results, setResults] = useState([])
const [isLoading, setIsLoading] = useState(false)
const { searchKnowledgeBase } = useMCP()
const handleSearch = async (params: SearchParams) => {
if (!params.query.trim()) return
setIsLoading(true)
try {
console.log(params)
const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
setResults(searchResults || [])
} catch (error) {
console.error('Search error:', error)
setResults([])
} finally {
setIsLoading(false)
}
}
return (
<div className="search-view">
<header className="search-header">
<button onClick={() => navigate('/ui/dashboard')} className="back-btn">
Back to Dashboard
</button>
<h2>🔍 Search Knowledge Base</h2>
</header>
<SearchForm isLoading={isLoading} onSearch={handleSearch} />
<SearchResults results={results} isLoading={isLoading} />
</div>
)
}
export default Search

View File

@@ -0,0 +1,151 @@
import { useMCP } from '@/hooks/useMCP'
import { useEffect, useState } from 'react'
import { DynamicFilters } from './DynamicFilters'
import { SelectableTags } from './SelectableTags'
import { CollectionMetadata } from '@/types/mcp'
type Filter = {
tags?: string[]
source_ids?: string[]
[key: string]: any
}
export interface SearchParams {
query: string
previews: boolean
modalities: string[]
filters: Filter
limit: number
}
interface SearchFormProps {
isLoading: boolean
onSearch: (params: SearchParams) => void
}
// Pure helper functions for SearchForm
const createFlags = (items: string[], defaultValue = false): Record<string, boolean> =>
items.reduce((acc, item) => ({ ...acc, [item]: defaultValue }), {})
const getSelectedItems = (items: Record<string, boolean>): string[] =>
Object.entries(items).filter(([_, selected]) => selected).map(([key]) => key)
const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
Object.entries(filters)
.filter(([_, value]) => value !== null && value !== '' && value !== undefined)
.reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {})
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
const [query, setQuery] = useState('')
const [previews, setPreviews] = useState(false)
const [modalities, setModalities] = useState<Record<string, boolean>>({})
const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
const [tags, setTags] = useState<Record<string, boolean>>({})
const [dynamicFilters, setDynamicFilters] = useState<Record<string, any>>({})
const [limit, setLimit] = useState(10)
const { getMetadataSchemas, getTags } = useMCP()
useEffect(() => {
const setupFilters = async () => {
const [schemas, tags] = await Promise.all([
getMetadataSchemas(),
getTags()
])
setSchemas(schemas)
setModalities(createFlags(Object.keys(schemas), true))
setTags(createFlags(tags))
}
setupFilters()
}, [getMetadataSchemas, getTags])
const handleFilterChange = (field: string, value: any) =>
setDynamicFilters(prev => ({ ...prev, [field]: value }))
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault()
onSearch({
query,
previews,
modalities: getSelectedItems(modalities),
filters: {
tags: getSelectedItems(tags),
...cleanFilters(dynamicFilters)
},
limit
})
}
return (
<form onSubmit={handleSubmit} className="search-form">
<div className="search-input-group">
<input
type="text"
value={query}
onChange={(e) => setQuery(e.target.value)}
placeholder="Search your knowledge base..."
className="search-input"
required
/>
<button type="submit" disabled={isLoading} className="search-btn">
{isLoading ? 'Searching...' : 'Search'}
</button>
</div>
<div className="search-options">
<div className="search-option">
<label>
<input
type="checkbox"
checked={previews}
onChange={(e) => setPreviews(e.target.checked)}
/>
Include content previews
</label>
</div>
<SelectableTags
title="Modalities"
className="modality-checkboxes"
tags={modalities}
onSelect={(tag, selected) => setModalities({ ...modalities, [tag]: selected })}
onBatchUpdate={(updates) => setModalities(updates)}
/>
<SelectableTags
title="Tags"
className="tags-container"
tags={tags}
onSelect={(tag, selected) => setTags({ ...tags, [tag]: selected })}
onBatchUpdate={(updates) => setTags(updates)}
searchable={true}
/>
<DynamicFilters
schemas={schemas}
selectedModalities={getSelectedItems(modalities)}
filters={dynamicFilters}
onFilterChange={handleFilterChange}
/>
<div className="search-option">
<label>
Max Results:
<input
type="number"
value={limit}
onChange={(e) => setLimit(parseInt(e.target.value) || 10)}
min={1}
max={100}
className="limit-input"
/>
</label>
</div>
</div>
</form>
)
}
export default SearchForm
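
For orientation, a hedged sketch of the payload handleSubmit assembles (values hypothetical): with two tags ticked and one dynamic filter filled in, onSearch receives roughly:

onSearch({
query: 'quarterly report',
previews: false,
modalities: ['mail', 'blog'],   // getSelectedItems(modalities)
filters: {
tags: ['work', 'finance'],      // getSelectedItems(tags)
min_size: 1024,                 // kept by cleanFilters (not null/''/undefined)
},
limit: 10,
})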

View File

@@ -0,0 +1,136 @@
interface SelectableTagProps {
tag: string
selected: boolean
onSelect: (tag: string, selected: boolean) => void
}
const SelectableTag = ({ tag, selected, onSelect }: SelectableTagProps) => {
return (
<span
className={`tag ${selected ? 'selected' : ''}`}
onClick={() => onSelect(tag, !selected)}
>
{tag}
</span>
)
}
import { useState } from 'react'
interface SelectableTagsProps {
title: string
className: string
tags: Record<string, boolean>
onSelect: (tag: string, selected: boolean) => void
onBatchUpdate?: (updates: Record<string, boolean>) => void
searchable?: boolean
}
export const SelectableTags = ({ title, className, tags, onSelect, onBatchUpdate, searchable = false }: SelectableTagsProps) => {
const [searchTerm, setSearchTerm] = useState('')
const handleSelectAll = () => {
if (onBatchUpdate) {
const updates = Object.keys(tags).reduce((acc, tag) => {
acc[tag] = true
return acc
}, {} as Record<string, boolean>)
onBatchUpdate(updates)
} else {
// Fallback to individual updates (though this won't work well)
Object.keys(tags).forEach(tag => {
if (!tags[tag]) {
onSelect(tag, true)
}
})
}
}
const handleDeselectAll = () => {
if (onBatchUpdate) {
const updates = Object.keys(tags).reduce((acc, tag) => {
acc[tag] = false
return acc
}, {} as Record<string, boolean>)
onBatchUpdate(updates)
} else {
// Fallback to individual updates (though this won't work well)
Object.keys(tags).forEach(tag => {
if (tags[tag]) {
onSelect(tag, false)
}
})
}
}
// Filter tags based on search term
const filteredTags = Object.entries(tags).filter(([tag, selected]) => {
return !searchTerm || tag.toLowerCase().includes(searchTerm.toLowerCase())
})
const selectedCount = Object.values(tags).filter(Boolean).length
const totalCount = Object.keys(tags).length
const filteredSelectedCount = filteredTags.filter(([_, selected]) => selected).length
const filteredTotalCount = filteredTags.length
const allSelected = selectedCount === totalCount
const noneSelected = selectedCount === 0
return (
<div className="search-option">
<details className="selectable-tags-details">
<summary className="selectable-tags-summary">
{title} ({selectedCount} selected)
</summary>
<div className="tag-controls">
<button
type="button"
className="tag-control-btn"
onClick={handleSelectAll}
disabled={allSelected}
>
All
</button>
<button
type="button"
className="tag-control-btn"
onClick={handleDeselectAll}
disabled={noneSelected}
>
None
</button>
<span className="tag-count">
({selectedCount}/{totalCount})
</span>
</div>
{searchable && (
<div className="tag-search-controls">
<input
type="text"
placeholder={`Search ${title.toLowerCase()}...`}
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="tag-search-input"
/>
{searchTerm && (
<span className="filtered-count">
Showing {filteredSelectedCount}/{filteredTotalCount}
</span>
)}
</div>
)}
<div className={`${className} tags-display-area`}>
{filteredTags.map(([tag, selected]: [string, boolean]) => (
<SelectableTag
key={tag}
tag={tag}
selected={selected}
onSelect={onSelect}
/>
))}
</div>
</details>
</div>
)
}
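
The "won't work well" fallback comments above allude to a stale-closure pitfall: the parent's onSelect is setTags({ ...tags, [tag]: selected }), which spreads the same captured tags object on every loop iteration, so only the last call survives. onBatchUpdate sidesteps this; a functional updater would as well (a sketch, not part of the commit):

// Reads the latest state on every call, so it is loop-safe:
onSelect={(tag, selected) => setTags(prev => ({ ...prev, [tag]: selected }))}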

View File

@@ -0,0 +1,3 @@
import Search from './Search'
export default Search

View File

@@ -1,10 +1,8 @@
-import React, { useState, useEffect } from 'react'
-import { useNavigate } from 'react-router-dom'
+import { useState, useEffect } from 'react'
import ReactMarkdown from 'react-markdown'
-import { useMCP } from '../hooks/useMCP'
-import Loading from './Loading'
+import { useMCP } from '@/hooks/useMCP'

-type SearchItem = {
+export type SearchItem = {
filename: string
content: string
chunks: any[]
@@ -13,7 +11,7 @@ type SearchItem = {
metadata: any
}

-const Tag = ({ tags }: { tags: string[] }) => {
+export const Tag = ({ tags }: { tags: string[] }) => {
return (
<div className="tags">
{tags?.map((tag: string, index: number) => (
@@ -23,11 +21,12 @@ const Tag = ({ tags }: { tags: string[] }) => {
)
}

-const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
+export const TextResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
return (
<div className="search-result-card">
<h4>{filename || 'Untitled'}</h4>
<Tag tags={tags} />
+<Metadata metadata={metadata} />
<p className="result-content">{content || 'No content available'}</p>
{chunks && chunks.length > 0 && (
<details className="result-chunks">
@@ -44,7 +43,7 @@ const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
)
}

-const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
+export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
return (
<div className="search-result-card">
<h4>{filename || 'Untitled'}</h4>
@@ -70,7 +69,7 @@ const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchIte
)
}

-const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
+export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
const title = metadata?.title || filename || 'Untitled'
const { fetchFile } = useMCP()
const [mime_type, setMimeType] = useState<string>()
@@ -95,7 +94,7 @@ const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
)
}

-const Metadata = ({ metadata }: { metadata: any }) => {
+export const Metadata = ({ metadata }: { metadata: any }) => {
if (!metadata) return null
return (
<div className="metadata">
@@ -108,7 +107,7 @@ const Metadata = ({ metadata }: { metadata: any }) => {
)
}

-const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
+export const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
return (
<div className="search-result-card">
<h4>{filename || 'Untitled'}</h4>
@@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
)
}

-const EmailResult = ({ content, tags, metadata }: SearchItem) => {
+export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
return (
<div className="search-result-card">
<h4>{metadata?.title || metadata?.subject || 'Untitled'}</h4>
@@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => {
)
}

-const SearchResult = ({ result }: { result: SearchItem }) => {
+export const SearchResult = ({ result }: { result: SearchItem }) => {
if (result.mime_type.startsWith('image/')) {
return <ImageResult {...result} />
}
@@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => {
return null
}
-const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
-if (isLoading) {
-return <Loading message="Searching..." />
-}
-return (
-<div className="search-results">
-{results.length > 0 && (
-<div className="results-count">
-Found {results.length} result{results.length !== 1 ? 's' : ''}
-</div>
-)}
-{results.map((result, index) => <SearchResult key={index} result={result} />)}
-{results.length === 0 && (
-<div className="no-results">
-No results found
-</div>
-)}
-</div>
-)
-}
-
-const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => {
-const [query, setQuery] = useState('')
-return (
-<form onSubmit={(e) => {
-e.preventDefault()
-onSearch(query)
-}} className="search-form">
-<div className="search-input-group">
-<input
-type="text"
-value={query}
-onChange={(e) => setQuery(e.target.value)}
-placeholder="Search your knowledge base..."
-className="search-input"
-/>
-<button type="submit" disabled={isLoading} className="search-btn">
-{isLoading ? 'Searching...' : 'Search'}
-</button>
-</div>
-</form>
-)
-}
-
-const Search = () => {
-const navigate = useNavigate()
-const [results, setResults] = useState([])
-const [isLoading, setIsLoading] = useState(false)
-const { searchKnowledgeBase } = useMCP()
-
-const handleSearch = async (query: string) => {
-if (!query.trim()) return
-setIsLoading(true)
-try {
-const searchResults = await searchKnowledgeBase(query)
-setResults(searchResults || [])
-} catch (error) {
-console.error('Search error:', error)
-setResults([])
-} finally {
-setIsLoading(false)
-}
-}
-
-return (
-<div className="search-view">
-<header className="search-header">
-<button onClick={() => navigate('/ui/dashboard')} className="back-btn">
-Back to Dashboard
-</button>
-<h2>🔍 Search Knowledge Base</h2>
-</header>
-<SearchForm isLoading={isLoading} onSearch={handleSearch} />
-<SearchResults results={results} isLoading={isLoading} />
-</div>
-)
-}
-
-export default Search
+export default SearchResult

View File

@@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000'
const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'

// Cookie utilities
-const getCookie = (name) => {
+const getCookie = (name: string) => {
const value = `; ${document.cookie}`
const parts = value.split(`; ${name}=`)
if (parts.length === 2) return parts.pop().split(';').shift()
return null
}

-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
const expires = new Date()
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
}

-const deleteCookie = (name) => {
+const deleteCookie = (name: string) => {
document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
}
@@ -68,6 +68,7 @@ export const useAuth = () => {
deleteCookie('access_token')
deleteCookie('refresh_token')
deleteCookie(SESSION_COOKIE_NAME)
+localStorage.removeItem('oauth_client_id')
setIsAuthenticated(false)
}, [])
@@ -110,7 +111,7 @@ export const useAuth = () => {
}, [logout])

// Make authenticated API calls with automatic token refresh
-const apiCall = useCallback(async (endpoint, options = {}) => {
+const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => {
let accessToken = getCookie('access_token')

if (!accessToken) {
@@ -122,7 +123,7 @@ export const useAuth = () => {
'Content-Type': 'application/json',
}

-const requestOptions = {
+const requestOptions: RequestInit & { headers: Record<string, string> } = {
...options,
headers: { ...defaultHeaders, ...options.headers },
}

View File

@@ -1,5 +1,5 @@
import { useEffect, useCallback } from 'react'
-import { useAuth } from './useAuth'
+import { useAuth } from '@/hooks/useAuth'

const parseServerSentEvents = async (response: Response): Promise<any> => {
const reader = response.body?.getReader()
@@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise<any> => {
}

export const useMCP = () => {
-const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth()
+const { apiCall, checkAuth } = useAuth()

-const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => {
-const response = await apiCall(`/mcp${path}`, {
+const mcpCall = useCallback(async (method: string, params: any = {}) => {
+const response = await apiCall(`/mcp/${method}`, {
method: 'POST',
headers: {
'Accept': 'application/json, text/event-stream',
@@ -118,22 +118,46 @@ export const useMCP = () => {
if (resp?.result?.isError) {
throw new Error(resp?.result?.content[0].text)
}
-return resp?.result?.content.map((item: any) => JSON.parse(item.text))
+return resp?.result?.content.map((item: any) => {
+try {
+return JSON.parse(item.text)
+} catch (e) {
+return item.text
+}
+})
}, [apiCall])

const listNotes = useCallback(async (path: string = "/") => {
-return await mcpCall('/note_files', 'note_files', { path })
+return await mcpCall('note_files', { path })
}, [mcpCall])

const fetchFile = useCallback(async (filename: string) => {
-return await mcpCall('/fetch_file', 'fetch_file', { filename })
+return await mcpCall('fetch_file', { filename })
}, [mcpCall])

-const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => {
-return await mcpCall('/search_knowledge_base', 'search_knowledge_base', {
+const getTags = useCallback(async () => {
+return await mcpCall('get_all_tags')
+}, [mcpCall])
+
+const getSubjects = useCallback(async () => {
+return await mcpCall('get_all_subjects')
+}, [mcpCall])
+
+const getObservationTypes = useCallback(async () => {
+return await mcpCall('get_all_observation_types')
+}, [mcpCall])
+
+const getMetadataSchemas = useCallback(async () => {
+return (await mcpCall('get_metadata_schemas'))[0]
+}, [mcpCall])
+
+const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+return await mcpCall('search_knowledge_base', {
query,
+filters,
+modalities,
previews,
-limit
+limit,
})
}, [mcpCall])
@@ -146,5 +170,9 @@ export const useMCP = () => {
fetchFile,
listNotes,
searchKnowledgeBase,
+getTags,
+getSubjects,
+getObservationTypes,
+getMetadataSchemas,
}
}
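
A component-side sketch of the extended hook (argument values hypothetical; the parameter order is query, previews, limit, filters, modalities):

const { searchKnowledgeBase } = useMCP()
// POSTs to /mcp/search_knowledge_base; filters and modalities are forwarded
// as JSON-RPC params alongside query, previews and limit.
const results = await searchKnowledgeBase(
'quarterly report',   // query
true,                 // previews
10,                   // limit
{ tags: ['work'] },   // filters
['mail'],             // modalities
)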

View File

@@ -14,7 +14,7 @@ const generateCodeVerifier = () => {
.replace(/=/g, '')
}

-const generateCodeChallenge = async (verifier) => {
+const generateCodeChallenge = async (verifier: string) => {
const data = new TextEncoder().encode(verifier)
const digest = await crypto.subtle.digest('SHA-256', data)
return btoa(String.fromCharCode(...new Uint8Array(digest)))
@@ -33,7 +33,7 @@ const generateState = () => {
}

// Storage utilities
-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
const expires = new Date()
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`

View File

@@ -1,6 +1,6 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
-import App from './App.jsx'
+import App from '@/App.jsx'

createRoot(document.getElementById('root')).render(
<StrictMode>

View File

@@ -0,0 +1,9 @@
export type CollectionMetadata = {
schema: Record<string, SchemaArg>
size: number
}
export type SchemaArg = {
type: string
description: string
}

frontend/tsconfig.json (new file)
View File

@@ -0,0 +1,40 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"lib": [
"ES2020",
"DOM",
"DOM.Iterable"
],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
/* Absolute imports */
"baseUrl": ".",
"paths": {
"@/*": [
"src/*"
]
}
},
"include": [
"src"
],
"references": [
{
"path": "./tsconfig.node.json"
}
]
}

View File

@@ -0,0 +1,12 @@
{
"compilerOptions": {
"composite": true,
"skipLibCheck": true,
"module": "ESNext",
"moduleResolution": "bundler",
"allowSyntheticDefaultImports": true
},
"include": [
"vite.config.js"
]
}

View File

@@ -1,8 +1,14 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
+import path from 'path'

// https://vite.dev/config/
export default defineConfig({
plugins: [react()],
base: '/ui/',
+resolve: {
+alias: {
+'@': path.resolve(__dirname, './src'),
+},
+},
})
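
The alias works in tandem with the paths entry in the new tsconfig.json: Vite rewrites @/ at bundle time while baseUrl/paths keeps the TypeScript compiler in agreement, so imports like the following resolve on both sides:

import { useMCP } from '@/hooks/useMCP'   // -> ./src/hooks/useMCP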

View File

@@ -1,2 +1,3 @@
import memory.api.MCP.manifest
import memory.api.MCP.memory
+import memory.api.MCP.metadata

View File

@@ -127,7 +127,6 @@ async def handle_login(request: Request):
        return login_form(request, oauth_params, "Invalid email or password")

    redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
-    print("redirect_url", redirect_url)
    if redirect_url.startswith("http://anysphere.cursor-retrieval"):
        redirect_url = redirect_url.replace("http://", "cursor://")
    return RedirectResponse(url=redirect_url, status_code=302)

View File

@@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system.
"""

import logging
+import mimetypes
import pathlib
from datetime import datetime, timezone
-import mimetypes
+from typing import Any

from pydantic import BaseModel
-from sqlalchemy import Text, func
+from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

+from memory.api.MCP.tools import mcp
from memory.api.search.search import SearchFilters, search
from memory.common import extract, settings
+from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
+from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
-from memory.common.db.models import AgentObservation, SourceItem
+from memory.common.db.models import SourceItem, AgentObservation
from memory.common.formatters import observation
-from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
-from memory.api.MCP.tools import mcp

logger = logging.getLogger(__name__)
@@ -47,11 +49,13 @@ def filter_observation_source_ids(
    return source_ids

-def filter_source_ids(
-    modalities: set[str],
-    tags: list[str] | None = None,
-):
-    if not tags:
+def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None:
+    if source_ids := filters.get("source_ids"):
+        return source_ids
+
+    tags = filters.get("tags")
+    size = filters.get("size")
+    if not (tags or size):
        return None

    with make_session() as session:
@@ -62,6 +66,8 @@ def filter_source_ids(
            items_query = items_query.filter(
                SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
            )
+        if size:
+            items_query = items_query.filter(SourceItem.size == size)
        if modalities:
            items_query = items_query.filter(SourceItem.modality.in_(modalities))
        source_ids = [item.id for item in items_query.all()]
@@ -69,51 +75,12 @@ def filter_source_ids(
    return source_ids
-@mcp.tool()
-async def get_all_tags() -> list[str]:
-    """
-    Get all unique tags used across the entire knowledge base.
-    Returns sorted list of tags from both observations and content.
-    """
-    with make_session() as session:
-        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
-        return sorted({row[0] for row in tags_query if row[0] is not None})
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
-    """
-    Get all unique subjects from observations about the user.
-    Returns sorted list of subject identifiers used in observations.
-    """
-    with make_session() as session:
-        return sorted(
-            r.subject for r in session.query(AgentObservation.subject).distinct()
-        )
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
-    """
-    Get all observation types that have been used.
-    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
-    """
-    with make_session() as session:
-        return sorted(
-            {
-                r.observation_type
-                for r in session.query(AgentObservation.observation_type).distinct()
-                if r.observation_type is not None
-            }
-        )
@mcp.tool()
async def search_knowledge_base(
    query: str,
-    previews: bool = False,
+    filters: dict[str, Any],
    modalities: set[str] = set(),
-    tags: list[str] = [],
+    previews: bool = False,
    limit: int = 10,
) -> list[dict]:
    """
@@ -125,7 +92,7 @@ async def search_knowledge_base(
    query: Natural language search query - be descriptive about what you're looking for
    previews: Include actual content in results - when false only a snippet is returned
    modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-    tags: Filter by tags - content must have at least one matching tag
+    filters: Filter by tags, source_ids, etc.
    limit: Max results (1-100)

    Returns: List of search results with id, score, chunks, content, filename
@@ -137,6 +104,9 @@ async def search_knowledge_base(
        modalities = set(ALL_COLLECTIONS.keys())
    modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS

+    search_filters = SearchFilters(**filters)
+    search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
+
    upload_data = extract.extract_text(query)
    results = await search(
        upload_data,
@@ -145,10 +115,7 @@ async def search_knowledge_base(
        limit=limit,
        min_text_score=0.4,
        min_multimodal_score=0.25,
-        filters=SearchFilters(
-            tags=tags,
-            source_ids=filter_source_ids(tags=tags, modalities=modalities),
-        ),
+        filters=search_filters,
    )

    return [result.model_dump() for result in results]
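
For clarity, a hedged sketch of the arguments an MCP client now passes to this tool (the exact filter keys depend on each collection's schema):

// JSON-RPC tool-call params, as the frontend's mcpCall would send them
{
query: 'quarterly report',
filters: { tags: ['work'], min_size: 1024 },
modalities: ['mail'],
previews: false,
limit: 10,
}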

View File

@@ -0,0 +1,119 @@
import logging
from collections import defaultdict
from typing import Annotated, TypedDict, get_args, get_type_hints
from memory.common import qdrant
from sqlalchemy import func
from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation
logger = logging.getLogger(__name__)
class SchemaArg(TypedDict):
type: str | None
description: str | None
class CollectionMetadata(TypedDict):
schema: dict[str, SchemaArg]
size: int
def from_annotation(annotation: Annotated) -> SchemaArg | None:
try:
type_, description = get_args(annotation)
type_str = str(type_)
if type_str.startswith("typing."):
type_str = type_str[7:]
elif len((parts := type_str.split("'"))) > 1:
type_str = parts[1]
return SchemaArg(type=type_str, description=description)
except IndexError:
logger.error(f"Error from annotation: {annotation}")
return None
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
if not hasattr(klass, "as_payload"):
return {}
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
return {}
return {
name: schema
for name, arg in payload_type.__annotations__.items()
if (schema := from_annotation(arg))
}
@mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
"""Get the metadata schema for each collection used in the knowledge base.
These schemas can be used to filter the knowledge base.
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
Example:
```
{
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
}
"""
client = qdrant.get_qdrant_client()
sizes = qdrant.get_collection_sizes(client)
schemas = defaultdict(dict)
for klass in SourceItem.__subclasses__():
for collection in klass.get_collections():
schemas[collection].update(get_schema(klass))
return {
collection: CollectionMetadata(schema=schema, size=size)
for collection, schema in schemas.items()
if (size := sizes.get(collection))
}
@mcp.tool()
async def get_all_tags() -> list[str]:
"""Get all unique tags used across the entire knowledge base.
Returns sorted list of tags from both observations and content.
"""
with make_session() as session:
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
return sorted({row[0] for row in tags_query if row[0] is not None})
@mcp.tool()
async def get_all_subjects() -> list[str]:
"""Get all unique subjects from observations about the user.
Returns sorted list of subject identifiers used in observations.
"""
with make_session() as session:
return sorted(
r.subject for r in session.query(AgentObservation.subject).distinct()
)
@mcp.tool()
async def get_all_observation_types() -> list[str]:
"""Get all observation types that have been used.
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
"""
with make_session() as session:
return sorted(
{
r.observation_type
for r in session.query(AgentObservation.observation_type).distinct()
if r.observation_type is not None
}
)
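
Mirrored by CollectionMetadata in the frontend's types/mcp.ts, a response from this tool might look like the following (illustrative values only, field descriptions taken from the payload annotations):

// Record<string, CollectionMetadata>
{
mail: {
size: 1204,
schema: {
subject: { type: 'str', description: 'Email subject line' },
folder: { type: 'str', description: 'Email folder name' },
},
},
}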

View File

@@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
# SQLAdmin setup with OAuth protection
engine = get_engine()
admin = Admin(app, engine)
+admin.app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # [settings.SERVER_URL]
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin)

View File

@@ -1,7 +1,7 @@
import base64
import io
import logging
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, cast

import qdrant_client
from PIL import Image
@@ -90,13 +90,71 @@ def query_chunks(
    }
def merge_range_filter(
filters: list[dict[str, Any]], key: str, val: Any
) -> list[dict[str, Any]]:
direction, field = key.split("_", maxsplit=1)
item = next((f for f in filters if f["key"] == field), None)
if not item:
item = {"key": field, "range": {}}
filters.append(item)
if direction == "min":
item["range"]["gte"] = val
elif direction == "max":
item["range"]["lte"] = val
return filters
def merge_filters(
filters: list[dict[str, Any]], key: str, val: Any
) -> list[dict[str, Any]]:
if not val and val != 0:
return filters
list_filters = ["tags", "recipients", "observation_types", "authors"]
range_filters = [
"min_sent_at",
"max_sent_at",
"min_published",
"max_published",
"min_size",
"max_size",
"min_created_at",
"max_created_at",
]
if key in list_filters:
filters.append({"key": key, "match": {"any": val}})
elif key in range_filters:
return merge_range_filter(filters, key, val)
elif key == "min_confidences":
confidence_filters = [
{
"key": f"confidence.{confidence_type}",
"range": {"gte": min_confidence_score},
}
for confidence_type, min_confidence_score in cast(dict, val).items()
]
filters.extend(confidence_filters)
elif key == "source_ids":
filters.append({"key": "id", "match": {"any": val}})
else:
filters.append({"key": key, "match": val})
return filters
async def search_embeddings(
    data: list[extract.DataChunk],
    previews: Optional[bool] = False,
    modalities: set[str] = set(),
    limit: int = 10,
    min_score: float = 0.3,
-    filters: SearchFilters = SearchFilters(),
+    filters: SearchFilters = {},
    multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """
@@ -111,27 +169,11 @@ async def search_embeddings(
    - filters: Filters to apply to the search results
    - multimodal: Whether to search in multimodal collections
    """
-    query_filters = {"must": []}
-
-    # Handle structured confidence filtering
-    if min_confidences := filters.get("min_confidences"):
-        confidence_filters = [
-            {
-                "key": f"confidence.{confidence_type}",
-                "range": {"gte": min_confidence_score},
-            }
-            for confidence_type, min_confidence_score in min_confidences.items()
-        ]
-        query_filters["must"].extend(confidence_filters)
-
-    if tags := filters.get("tags"):
-        query_filters["must"].append({"key": "tags", "match": {"any": tags}})
-
-    if observation_types := filters.get("observation_types"):
-        query_filters["must"].append(
-            {"key": "observation_type", "match": {"any": observation_types}}
-        )
+    search_filters = []
+    for key, val in filters.items():
+        search_filters = merge_filters(search_filters, key, val)

+    print(search_filters)
    client = qdrant.get_qdrant_client()
    results = query_chunks(
        client,
@@ -140,7 +182,7 @@ async def search_embeddings(
        embedding.embed_text if not multimodal else embedding.embed_mixed,
        min_score=min_score,
        limit=limit,
-        filters=query_filters if query_filters["must"] else None,
+        filters={"must": search_filters} if search_filters else None,
    )

    search_results = {k: results.get(k, []) for k in modalities}
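
To make the merge concrete, a sketch of the must clause merge_filters builds from a mixed filter set (values hypothetical):

// Input: { tags: ['work'], min_size: 1024, max_size: 50000, source_ids: [1, 2] }
// Resulting list, passed as {"must": [...]} to query_chunks:
[
{ key: 'tags', match: { any: ['work'] } },
{ key: 'size', range: { gte: 1024, lte: 50000 } },   // min_/max_ merged into one range filter
{ key: 'id', match: { any: [1, 2] } },               // source_ids matches on the id payload field
]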

View File

@@ -17,6 +17,7 @@ from memory.common.collections import (
    MULTIMODAL_COLLECTIONS,
    TEXT_COLLECTIONS,
)
+from memory.common import settings

logger = logging.getLogger(__name__)
@@ -44,51 +45,57 @@ async def search(
    - List of search results sorted by score
    """
    allowed_modalities = modalities & ALL_COLLECTIONS.keys()
-    print(allowed_modalities)

-    text_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & TEXT_COLLECTIONS,
-            limit,
-            min_text_score,
-            filters,
-            multimodal=False,
-        ),
-        timeout,
-    )
-    multimodal_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & MULTIMODAL_COLLECTIONS,
-            limit,
-            min_multimodal_score,
-            filters,
-            multimodal=True,
-        ),
-        timeout,
-    )
-    bm25_results = with_timeout(
-        search_bm25(
-            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
-            modalities,
-            limit=limit,
-            filters=filters,
-        ),
-        timeout,
-    )
+    searches = []
+    if settings.ENABLE_EMBEDDING_SEARCH:
+        searches = [
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & TEXT_COLLECTIONS,
+                    limit,
+                    min_text_score,
+                    filters,
+                    multimodal=False,
+                ),
+                timeout,
+            ),
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & MULTIMODAL_COLLECTIONS,
+                    limit,
+                    min_multimodal_score,
+                    filters,
+                    multimodal=True,
+                ),
+                timeout,
+            ),
+        ]
+    if settings.ENABLE_BM25_SEARCH:
+        searches.append(
+            with_timeout(
+                search_bm25(
+                    " ".join(
+                        [c for chunk in data for c in chunk.data if isinstance(c, str)]
+                    ),
+                    modalities,
+                    limit=limit,
+                    filters=filters,
+                ),
+                timeout,
+            )
+        )

-    results = await asyncio.gather(
-        text_embeddings_results,
-        multimodal_embeddings_results,
-        bm25_results,
-        return_exceptions=False,
-    )
-    text_results, multi_results, bm25_results = results
-    all_results = text_results + multi_results
-    if len(all_results) < limit:
-        all_results += bm25_results
+    search_results = await asyncio.gather(*searches, return_exceptions=False)
+    all_results = []
+    for results in search_results:
+        if len(all_results) >= limit:
+            break
+        all_results.extend(results)

    results = group_chunks(all_results, previews or False)
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)

View File

@@ -65,9 +65,9 @@ class SearchResult(BaseModel):

class SearchFilters(TypedDict):
-    subject: NotRequired[str | None]
+    min_size: NotRequired[int]
+    max_size: NotRequired[int]
    min_confidences: NotRequired[dict[str, float]]
-    tags: NotRequired[list[str] | None]
    observation_types: NotRequired[list[str] | None]
    source_ids: NotRequired[list[int] | None]
@@ -115,7 +115,6 @@ def group_chunks(
        if isinstance(contents, dict):
            tags = contents.pop("tags", [])
            content = contents.pop("content", None)
-            print(content)
        else:
            content = contents
            contents = {}

View File

@@ -4,6 +4,7 @@ from memory.common.db.models.source_item import (
    SourceItem,
    ConfidenceScore,
    clean_filename,
+    SourceItemPayload,
)
from memory.common.db.models.source_items import (
    MailMessage,
@@ -19,6 +20,14 @@ from memory.common.db.models.source_items import (
    Photo,
    MiscDoc,
    Note,
+    MailMessagePayload,
+    EmailAttachmentPayload,
+    AgentObservationPayload,
+    BlogPostPayload,
+    ComicPayload,
+    BookSectionPayload,
+    NotePayload,
+    ForumPostPayload,
)
from memory.common.db.models.observations import (
    ObservationContradiction,
@@ -40,6 +49,18 @@ from memory.common.db.models.users import (
    OAuthRefreshToken,
)

+Payload = (
+    SourceItemPayload
+    | AgentObservationPayload
+    | NotePayload
+    | BlogPostPayload
+    | ComicPayload
+    | BookSectionPayload
+    | ForumPostPayload
+    | EmailAttachmentPayload
+    | MailMessagePayload
+)
+
__all__ = [
    "Base",
    "Chunk",
@@ -75,4 +96,6 @@ __all__ = [
    "OAuthClientInformation",
    "OAuthState",
    "OAuthRefreshToken",
+    # Payloads
+    "Payload",
]

View File

@@ -4,7 +4,7 @@ Database models for the knowledge base system.

 import pathlib
 import re
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, TypedDict, cast
 import uuid

 from PIL import Image
@@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer
 from memory.common.db.models.base import Base

+class MetadataSchema(TypedDict):
+    type: str
+    description: str
+
+
+class SourceItemPayload(TypedDict):
+    source_id: Annotated[int, "Unique identifier of the source item"]
+    tags: Annotated[list[str], "List of tags for categorization"]
+    size: Annotated[int | None, "Size of the content in bytes"]
+
+
 @event.listens_for(Session, "before_flush")
 def handle_duplicate_sha256(session, flush_context, instances):
     """
@@ -344,12 +355,17 @@ class SourceItem(Base):
     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
         return [self._make_chunk(data, metadata) for data in self._chunk_contents()]

-    def as_payload(self) -> dict:
-        return {
-            "source_id": self.id,
-            "tags": self.tags,
-            "size": self.size,
-        }
+    def as_payload(self) -> SourceItemPayload:
+        return SourceItemPayload(
+            source_id=cast(int, self.id),
+            tags=cast(list[str], self.tags),
+            size=cast(int | None, self.size),
+        )
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """Return the list of Qdrant collections this SourceItem type can be stored in."""
+        return [cls.__tablename__]

     @property
     def display_contents(self) -> str | dict | None:
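
`get_collections()` establishes the default routing rule, one Qdrant collection named after the table, which the subclasses below override (`mail`, `book`, `text`, `semantic`/`temporal`, ...). A hedged sketch of how an indexer might use it to build a collection-to-model routing table (the helper itself is hypothetical):

from collections import defaultdict
from memory.common.db.models import SourceItem

def collection_routes() -> dict[str, list[type[SourceItem]]]:
    # Map each collection to the SourceItem subclasses that may write to it.
    # Note: __subclasses__() only lists direct subclasses; recurse if needed.
    routes: dict[str, list[type[SourceItem]]] = defaultdict(list)
    for model in SourceItem.__subclasses__():
        for collection in model.get_collections():
            routes[collection].append(model)
    return dict(routes)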

View File

@@ -5,7 +5,7 @@ Database models for the knowledge base system.

 import pathlib
 import textwrap
 from datetime import datetime
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, cast

 from PIL import Image
 from sqlalchemy import (
@@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation
 from memory.common.db.models.source_item import (
     SourceItem,
     Chunk,
+    SourceItemPayload,
     clean_filename,
     chunk_mixed,
 )

+class MailMessagePayload(SourceItemPayload):
+    message_id: Annotated[str, "Unique email message identifier"]
+    subject: Annotated[str, "Email subject line"]
+    sender: Annotated[str, "Email sender address"]
+    recipients: Annotated[list[str], "List of recipient email addresses"]
+    folder: Annotated[str, "Email folder name"]
+    date: Annotated[str | None, "Email sent date in ISO format"]
+
+
 class MailMessage(SourceItem):
     __tablename__ = "mail_message"
@@ -80,17 +90,21 @@ class MailMessage(SourceItem):
         path.parent.mkdir(parents=True, exist_ok=True)
         return path

-    def as_payload(self) -> dict:
-        return {
-            **super().as_payload(),
-            "message_id": self.message_id,
-            "subject": self.subject,
-            "sender": self.sender,
-            "recipients": self.recipients,
-            "folder": self.folder,
-            "tags": self.tags + [self.sender] + self.recipients,
-            "date": (self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
-        }
+    def as_payload(self) -> MailMessagePayload:
+        base_payload = super().as_payload() | {
+            "tags": cast(list[str], self.tags)
+            + [cast(str, self.sender)]
+            + cast(list[str], self.recipients)
+        }
+        return MailMessagePayload(
+            **cast(dict, base_payload),
+            message_id=cast(str, self.message_id),
+            subject=cast(str, self.subject),
+            sender=cast(str, self.sender),
+            recipients=cast(list[str], self.recipients),
+            folder=cast(str, self.folder),
+            date=(self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
+        )

     @property
     def parsed_content(self) -> dict[str, Any]:
@@ -152,7 +166,7 @@ class MailMessage(SourceItem):
     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = self.parsed_content
-        chunks = extract.extract_text(cast(str, self.body))
+        chunks = extract.extract_text(cast(str, self.body), modality="email")

         def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
             if isinstance(item, str):
@@ -163,6 +177,10 @@ class MailMessage(SourceItem):
             chunk.data = [add_header(item) for item in chunk.data]
         return chunks

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["mail"]
+
     # Add indexes
     __table_args__ = (
         Index("mail_sent_idx", "sent_at"),
@@ -171,6 +189,13 @@ class MailMessage(SourceItem):
     )

+class EmailAttachmentPayload(SourceItemPayload):
+    filename: Annotated[str, "Name of the document file"]
+    content_type: Annotated[str, "MIME type of the document"]
+    mail_message_id: Annotated[int, "Associated email message ID"]
+    sent_at: Annotated[str | None, "Document creation timestamp"]
+
+
 class EmailAttachment(SourceItem):
     __tablename__ = "email_attachment"
@@ -190,17 +215,20 @@ class EmailAttachment(SourceItem):
         "polymorphic_identity": "email_attachment",
     }

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> EmailAttachmentPayload:
+        return EmailAttachmentPayload(
             **super().as_payload(),
-            "filename": self.filename,
-            "content_type": self.mime_type,
-            "size": self.size,
-            "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
-            "mail_message_id": self.mail_message_id,
-        }
+            filename=cast(str, self.filename),
+            content_type=cast(str, self.mime_type),
+            mail_message_id=cast(int, self.mail_message_id),
+            sent_at=(
+                self.mail_message.sent_at
+                and self.mail_message.sent_at.isoformat()
+                or None
+            ),  # type: ignore
+        )

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         if cast(str | None, self.filename):
             contents = (
                 settings.FILE_STORAGE_DIR / cast(str, self.filename)
@@ -208,8 +236,7 @@ class EmailAttachment(SourceItem):
         else:
             contents = cast(str, self.content)

-        chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
-        return [self._make_chunk(c, metadata) for c in chunks]
+        return extract.extract_data_chunks(cast(str, self.mime_type), contents)

     @property
     def display_contents(self) -> dict:
@@ -221,6 +248,11 @@ class EmailAttachment(SourceItem):
     # Add indexes
     __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """EmailAttachment can go to different collections based on mime_type"""
+        return ["doc", "text", "blog", "photo", "book"]
+

 class ChatMessage(SourceItem):
     __tablename__ = "chat_message"
@@ -285,6 +317,16 @@ class Photo(SourceItem):
     __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)

+class ComicPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the comic"]
+    author: Annotated[str | None, "Author of the comic"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    volume: Annotated[str | None, "Volume number"]
+    issue: Annotated[str | None, "Issue number"]
+    page: Annotated[int | None, "Page number"]
+    url: Annotated[str | None, "URL of the comic"]
+
+
 class Comic(SourceItem):
     __tablename__ = "comic"
@@ -305,18 +347,17 @@ class Comic(SourceItem):
     __table_args__ = (Index("comic_author_idx", "author"),)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> ComicPayload:
+        return ComicPayload(
             **super().as_payload(),
-            "title": self.title,
-            "author": self.author,
-            "published": self.published,
-            "volume": self.volume,
-            "issue": self.issue,
-            "page": self.page,
-            "url": self.url,
-        }
-        return {k: v for k, v in payload.items() if v is not None}
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(self.published and self.published.isoformat() or None),  # type: ignore
+            volume=cast(str | None, self.volume),
+            issue=cast(str | None, self.issue),
+            page=cast(int | None, self.page),
+            url=cast(str | None, self.url),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
@@ -324,6 +365,17 @@ class Comic(SourceItem):
         return [extract.DataChunk(data=[image, description])]

+class BookSectionPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the book"]
+    author: Annotated[str | None, "Author of the book"]
+    book_id: Annotated[int, "Unique identifier of the book"]
+    section_title: Annotated[str, "Title of the section"]
+    section_number: Annotated[int, "Number of the section"]
+    section_level: Annotated[int, "Level of the section"]
+    start_page: Annotated[int, "Starting page number"]
+    end_page: Annotated[int, "Ending page number"]
+
+
 class BookSection(SourceItem):
     """Individual sections/chapters of books"""
@@ -361,19 +413,22 @@ class BookSection(SourceItem):
         Index("book_section_level_idx", "section_level", "section_number"),
     )

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["book"]
+
-    def as_payload(self) -> dict:
-        vals = {
+    def as_payload(self) -> BookSectionPayload:
+        return BookSectionPayload(
             **super().as_payload(),
-            "title": self.book.title,
-            "author": self.book.author,
-            "book_id": self.book_id,
-            "section_title": self.section_title,
-            "section_number": self.section_number,
-            "section_level": self.section_level,
-            "start_page": self.start_page,
-            "end_page": self.end_page,
-        }
-        return {k: v for k, v in vals.items() if v}
+            title=cast(str, self.book.title),
+            author=cast(str | None, self.book.author),
+            book_id=cast(int, self.book_id),
+            section_title=cast(str, self.section_title),
+            section_number=cast(int, self.section_number),
+            section_level=cast(int, self.section_level),
+            start_page=cast(int, self.start_page),
+            end_page=cast(int, self.end_page),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = cast(str, self.content.strip())
@@ -397,6 +452,16 @@ class BookSection(SourceItem):
         ]

+class BlogPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the blog post"]
+    title: Annotated[str, "Title of the blog post"]
+    author: Annotated[str | None, "Author of the blog post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    description: Annotated[str | None, "Description of the blog post"]
+    domain: Annotated[str | None, "Domain of the blog post"]
+    word_count: Annotated[int | None, "Word count of the blog post"]
+
+
 class BlogPost(SourceItem):
     __tablename__ = "blog_post"
@@ -428,27 +493,39 @@ class BlogPost(SourceItem):
         Index("blog_post_word_count_idx", "word_count"),
     )

-    def as_payload(self) -> dict:
+    def as_payload(self) -> BlogPostPayload:
         published_date = cast(datetime | None, self.published)
         metadata = cast(dict | None, self.webpage_metadata) or {}
-        payload = {
+        return BlogPostPayload(
             **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "author": self.author,
-            "published": published_date and published_date.isoformat(),
-            "description": self.description,
-            "domain": self.domain,
-            "word_count": self.word_count,
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(published_date and published_date.isoformat() or None),  # type: ignore
+            description=cast(str | None, self.description),
+            domain=cast(str | None, self.domain),
+            word_count=cast(int | None, self.word_count),
             **metadata,
-        }
-        return {k: v for k, v in payload.items() if v}
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))

+class ForumPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the forum post"]
+    title: Annotated[str, "Title of the forum post"]
+    description: Annotated[str | None, "Description of the forum post"]
+    authors: Annotated[list[str] | None, "Authors of the forum post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    slug: Annotated[str | None, "Slug of the forum post"]
+    karma: Annotated[int | None, "Karma score of the forum post"]
+    votes: Annotated[int | None, "Number of votes on the forum post"]
+    score: Annotated[int | None, "Score of the forum post"]
+    comments: Annotated[int | None, "Number of comments on the forum post"]
+
+
 class ForumPost(SourceItem):
     __tablename__ = "forum_post"
@@ -479,20 +556,20 @@ class ForumPost(SourceItem):
         Index("forum_post_title_idx", "title"),
     )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> ForumPostPayload:
+        return ForumPostPayload(
             **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "description": self.description,
-            "authors": self.authors,
-            "published_at": self.published_at,
-            "slug": self.slug,
-            "karma": self.karma,
-            "votes": self.votes,
-            "score": self.score,
-            "comments": self.comments,
-        }
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            description=cast(str | None, self.description),
+            authors=cast(list[str] | None, self.authors),
+            published=(self.published_at and self.published_at.isoformat() or None),  # type: ignore
+            slug=cast(str | None, self.slug),
+            karma=cast(int | None, self.karma),
+            votes=cast(int | None, self.votes),
+            score=cast(int | None, self.score),
+            comments=cast(int | None, self.comments),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
@@ -545,6 +622,12 @@ class GithubItem(SourceItem):
     )

+class NotePayload(SourceItemPayload):
+    note_type: Annotated[str | None, "Category of the note"]
+    subject: Annotated[str | None, "What the note is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the note"]
+
+
 class Note(SourceItem):
     """A quick note of something of interest."""
@@ -565,13 +648,13 @@ class Note(SourceItem):
         Index("note_subject_idx", "subject"),
     )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> NotePayload:
+        return NotePayload(
             **super().as_payload(),
-            "note_type": self.note_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-        }
+            note_type=cast(str | None, self.note_type),
+            subject=cast(str | None, self.subject),
+            confidence=self.confidence_dict,
+        )

     @property
     def display_contents(self) -> dict:
@@ -602,6 +685,19 @@ class Note(SourceItem):
             self.as_text(cast(str, self.content), cast(str | None, self.subject))
         )

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["text"]  # Notes go to the text collection
+
+
+class AgentObservationPayload(SourceItemPayload):
+    session_id: Annotated[str | None, "Session ID for the observation"]
+    observation_type: Annotated[str, "Type of observation"]
+    subject: Annotated[str, "What/who the observation is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the observation"]
+    evidence: Annotated[dict | None, "Supporting context, quotes, etc."]
+    agent_model: Annotated[str, "Which AI model made this observation"]
+
+
 class AgentObservation(SourceItem):
     """
@@ -652,18 +748,16 @@ class AgentObservation(SourceItem):
         kwargs["modality"] = "observation"
         super().__init__(**kwargs)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> AgentObservationPayload:
+        return AgentObservationPayload(
             **super().as_payload(),
-            "observation_type": self.observation_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-            "evidence": self.evidence,
-            "agent_model": self.agent_model,
-        }
-        if self.session_id is not None:
-            payload["session_id"] = str(self.session_id)
-        return payload
+            observation_type=cast(str, self.observation_type),
+            subject=cast(str, self.subject),
+            confidence=self.confidence_dict,
+            evidence=cast(dict | None, self.evidence),
+            agent_model=cast(str, self.agent_model),
+            session_id=cast(str | None, self.session_id) and str(self.session_id),
+        )

     @property
     def all_contradictions(self):
@@ -759,3 +853,7 @@ class AgentObservation(SourceItem):
         #     ))
         return chunks
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["semantic", "temporal"]

View File

@@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
     return image

-def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
+def doc_to_images(
+    content: bytes | str | pathlib.Path, modality: str = "doc"
+) -> list[DataChunk]:
     with as_file(content) as file_path:
         with pymupdf.open(file_path) as pdf:
             return [
@@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
                         "height": page.rect.height,
                     },
                     mime_type="image/jpeg",
+                    modality=modality,
                 )
                 for page in pdf.pages()
             ]
@@ -122,6 +125,7 @@ def extract_text(
     content: bytes | str | pathlib.Path,
     chunk_size: int | None = None,
     metadata: dict[str, Any] = {},
+    modality: str = "text",
 ) -> list[DataChunk]:
     if isinstance(content, pathlib.Path):
         content = content.read_text()
@@ -130,7 +134,7 @@ def extract_text(
     content = cast(str, content)

     chunks = [
-        DataChunk(data=[c], modality="text", metadata=metadata)
+        DataChunk(data=[c], modality=modality, metadata=metadata)
         for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
     ]
     if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
@@ -139,7 +143,7 @@ def extract_text(
             DataChunk(
                 data=[summary],
                 metadata=merge_metadata(metadata, {"tags": tags}),
-                modality="text",
+                modality=modality,
             )
         )
     return chunks
@@ -158,9 +162,7 @@ def extract_data_chunks(
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
-        logger.info(f"Extracting content from {content}")
         chunks = extract_docx(content)
-        logger.info(f"Extracted {len(chunks)} pages from {content}")
     elif mime_type.startswith("text/"):
         chunks = extract_text(content, chunk_size)
     elif mime_type.startswith("image/"):

View File

@@ -224,6 +224,15 @@ def get_collection_info(
     return info.model_dump()

+def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
+    """Get the size of each collection."""
+    collections = [i.name for i in client.get_collections().collections]
+
+    return {
+        collection_name: client.count(collection_name).count  # type: ignore
+        for collection_name in collections
+    }
+
+
 def batch_ids(
     client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
 ) -> Generator[list[str], None, None]:
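
`get_collection_sizes` is a thin convenience over `client.get_collections()` plus one `count()` call per collection, cheap enough for a stats endpoint. A usage sketch (the local URL is an assumption, and `get_collection_sizes` from the hunk above is presumed importable):

import qdrant_client

client = qdrant_client.QdrantClient(url="http://localhost:6333")  # URL assumed
for name, size in get_collection_sizes(client).items():
    print(f"{name}: {size} points")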

View File

@@ -6,6 +6,8 @@ load_dotenv()

 def boolean_env(key: str, default: bool = False) -> bool:
+    if key not in os.environ:
+        return default
     return os.getenv(key, "0").lower() in ("1", "true", "yes")
@@ -130,6 +132,10 @@ else:
 ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")

+# Search settings
+ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
+ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+
 # API settings
 SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
 HTTPS = boolean_env("HTTPS", False)
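
The `boolean_env` change fixes a real bug: an unset variable previously fell through to `os.getenv(key, "0")`, so a `default=True` (as used by the two new search flags) was silently ignored. Behavior after the fix, assuming the patched `boolean_env` above is in scope:

import os

os.environ.pop("ENABLE_BM25_SEARCH", None)
assert boolean_env("ENABLE_BM25_SEARCH", True) is True    # unset -> default honored

os.environ["ENABLE_BM25_SEARCH"] = "no"
assert boolean_env("ENABLE_BM25_SEARCH", True) is False   # set but not truthy -> False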

View File

@@ -0,0 +1,174 @@
import pytest

from memory.api.search.embeddings import merge_range_filter, merge_filters


def test_merge_range_filter_new_filter():
    """Test adding new range filters"""
    filters = []
    result = merge_range_filter(filters, "min_size", 100)
    assert result == [{"key": "size", "range": {"gte": 100}}]

    filters = []
    result = merge_range_filter(filters, "max_size", 1000)
    assert result == [{"key": "size", "range": {"lte": 1000}}]


def test_merge_range_filter_existing_field():
    """Test adding to existing field"""
    filters = [{"key": "size", "range": {"lte": 1000}}]
    result = merge_range_filter(filters, "min_size", 100)
    assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]


def test_merge_range_filter_override_existing():
    """Test overriding existing values"""
    filters = [{"key": "size", "range": {"gte": 100}}]
    result = merge_range_filter(filters, "min_size", 200)
    assert result == [{"key": "size", "range": {"gte": 200}}]


def test_merge_range_filter_with_other_filters():
    """Test adding range filter alongside other filters"""
    filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
    result = merge_range_filter(filters, "min_size", 100)
    expected = [
        {"key": "tags", "match": {"any": ["tag1"]}},
        {"key": "size", "range": {"gte": 100}},
    ]
    assert result == expected


@pytest.mark.parametrize(
    "key,expected_direction,expected_field",
    [
        ("min_sent_at", "min", "sent_at"),
        ("max_sent_at", "max", "sent_at"),
        ("min_published", "min", "published"),
        ("max_published", "max", "published"),
        ("min_size", "min", "size"),
        ("max_size", "max", "size"),
    ],
)
def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
    """Test that field names are correctly extracted from keys"""
    filters = []
    result = merge_range_filter(filters, key, 100)

    assert len(result) == 1
    assert result[0]["key"] == expected_field
    range_key = "gte" if expected_direction == "min" else "lte"
    assert result[0]["range"][range_key] == 100


@pytest.mark.parametrize(
    "filter_key,filter_value",
    [
        ("tags", ["tag1", "tag2"]),
        ("recipients", ["user1", "user2"]),
        ("observation_types", ["belief", "preference"]),
        ("authors", ["author1"]),
    ],
)
def test_merge_filters_list_filters(filter_key, filter_value):
    """Test list filters that use match any"""
    filters = []
    result = merge_filters(filters, filter_key, filter_value)
    expected = [{"key": filter_key, "match": {"any": filter_value}}]
    assert result == expected


def test_merge_filters_min_confidences():
    """Test min_confidences filter creates multiple range conditions"""
    filters = []
    confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9}
    result = merge_filters(filters, "min_confidences", confidences)

    expected = [
        {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
        {"key": "confidence.source_reliability", "range": {"gte": 0.9}},
    ]
    assert result == expected


def test_merge_filters_source_ids():
    """Test source_ids filter maps to id field"""
    filters = []
    result = merge_filters(filters, "source_ids", ["id1", "id2"])
    expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}]
    assert result == expected


def test_merge_filters_range_delegation():
    """Test range filters are properly delegated to merge_range_filter"""
    filters = []
    result = merge_filters(filters, "min_size", 100)
    assert len(result) == 1
    assert "range" in result[0]
    assert result[0]["range"]["gte"] == 100


def test_merge_filters_combined_range():
    """Test that min/max range pairs merge into single filter"""
    filters = []
    filters = merge_filters(filters, "min_size", 100)
    filters = merge_filters(filters, "max_size", 1000)

    size_filters = [f for f in filters if f["key"] == "size"]
    assert len(size_filters) == 1
    assert size_filters[0]["range"]["gte"] == 100
    assert size_filters[0]["range"]["lte"] == 1000


def test_merge_filters_preserves_existing():
    """Test that existing filters are preserved when adding new ones"""
    existing_filters = [{"key": "existing", "match": "value"}]
    result = merge_filters(existing_filters, "tags", ["new_tag"])

    assert len(result) == 2
    assert {"key": "existing", "match": "value"} in result
    assert {"key": "tags", "match": {"any": ["new_tag"]}} in result


def test_merge_filters_realistic_combination():
    """Test a realistic filter combination for knowledge base search"""
    filters = []

    # Add typical knowledge base filters
    filters = merge_filters(filters, "tags", ["important", "work"])
    filters = merge_filters(filters, "min_published", "2023-01-01")
    filters = merge_filters(filters, "max_size", 1000000)  # 1MB max
    filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})

    assert len(filters) == 4

    # Check each filter type
    tag_filter = next(f for f in filters if f["key"] == "tags")
    assert tag_filter["match"]["any"] == ["important", "work"]

    published_filter = next(f for f in filters if f["key"] == "published")
    assert published_filter["range"]["gte"] == "2023-01-01"

    size_filter = next(f for f in filters if f["key"] == "size")
    assert size_filter["range"]["lte"] == 1000000

    confidence_filter = next(
        f for f in filters if f["key"] == "confidence.observation_accuracy"
    )
    assert confidence_filter["range"]["gte"] == 0.8


def test_merge_filters_unknown_key():
    """Test fallback behavior for unknown filter keys"""
    filters = []
    result = merge_filters(filters, "unknown_field", "unknown_value")
    expected = [{"key": "unknown_field", "match": "unknown_value"}]
    assert result == expected


def test_merge_filters_empty_min_confidences():
    """Test min_confidences with empty dict does nothing"""
    filters = []
    result = merge_filters(filters, "min_confidences", {})
    assert result == []