mirror of https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00

search filters

commit 3e4e5872d1 (parent 780e27ba04)
@@ -285,6 +285,7 @@ body {
  display: flex;
  gap: 1rem;
  align-items: center;
  margin-bottom: 1.5rem;
}

.search-input {
@@ -323,6 +324,34 @@ body {
  cursor: not-allowed;
}

.search-options {
  border-top: 1px solid #e2e8f0;
  padding-top: 1.5rem;
  display: grid;
  gap: 1.5rem;
}

.search-option {
  display: flex;
  flex-direction: column;
  gap: 0.5rem;
}

.search-option label {
  font-weight: 500;
  color: #4a5568;
  font-size: 0.9rem;
  display: flex;
  align-items: center;
  gap: 0.5rem;
}

.search-option input[type="checkbox"] {
  margin-right: 0.5rem;
  transform: scale(1.1);
  accent-color: #667eea;
}

/* Search Results */
.search-results {
  margin-top: 2rem;
@@ -429,6 +458,11 @@ body {
  transition: background-color 0.2s;
}

.tag.selected {
  background: #667eea;
  color: white;
}

.tag:hover {
  background: #e2e8f0;
}
@@ -495,6 +529,14 @@ body {
  padding: 1.5rem;
}

.search-options {
  gap: 1rem;
}

.limit-input {
  width: 80px;
}

.search-result-card {
  padding: 1rem;
}
@@ -689,4 +731,194 @@ body {

.markdown-preview a:hover {
  color: #2b6cb0;
}

/* Dynamic filters styles */
.modality-filters {
  display: flex;
  flex-direction: column;
  gap: 1rem;
  margin-top: 0.5rem;
}

.modality-filter-group {
  border: 1px solid #e2e8f0;
  border-radius: 8px;
  background: white;
}

.modality-filter-title {
  padding: 0.75rem 1rem;
  font-weight: 600;
  color: #4a5568;
  background: #f8fafc;
  border-radius: 7px 7px 0 0;
  cursor: pointer;
  user-select: none;
  transition: background-color 0.2s;
}

.modality-filter-title:hover {
  background: #edf2f7;
}

.modality-filter-group[open] .modality-filter-title {
  border-bottom: 1px solid #e2e8f0;
  border-radius: 7px 7px 0 0;
}

.filters-grid {
  display: grid;
  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
  gap: 1rem;
  padding: 1rem;
}

.filter-field {
  display: flex;
  flex-direction: column;
  gap: 0.25rem;
}

.filter-label {
  font-size: 0.875rem;
  font-weight: 500;
  color: #4a5568;
  margin-bottom: 0.25rem;
}

.filter-input {
  padding: 0.5rem;
  border: 1px solid #e2e8f0;
  border-radius: 4px;
  font-size: 0.875rem;
  background: white;
  transition: border-color 0.2s;
}

.filter-input:focus {
  outline: none;
  border-color: #667eea;
  box-shadow: 0 0 0 1px #667eea;
}

/* Selectable tags controls */
.selectable-tags-details {
  border: 1px solid #e2e8f0;
  border-radius: 8px;
  background: white;
}

.selectable-tags-summary {
  padding: 0.75rem 1rem;
  font-weight: 600;
  color: #4a5568;
  cursor: pointer;
  user-select: none;
  transition: background-color 0.2s;
  border-radius: 7px;
}

.selectable-tags-summary:hover {
  background: #f8fafc;
}

.selectable-tags-details[open] .selectable-tags-summary {
  border-bottom: 1px solid #e2e8f0;
  border-radius: 7px 7px 0 0;
  background: #f8fafc;
}

.tag-controls {
  display: flex;
  align-items: center;
  gap: 0.5rem;
  padding: 0.75rem 1rem;
  border-bottom: 1px solid #e2e8f0;
  background: #fafbfc;
}

.tag-control-btn {
  background: #f7fafc;
  border: 1px solid #e2e8f0;
  border-radius: 4px;
  padding: 0.25rem 0.5rem;
  font-size: 0.75rem;
  color: #4a5568;
  cursor: pointer;
  transition: all 0.2s;
}

.tag-control-btn:hover:not(:disabled) {
  background: #edf2f7;
  border-color: #cbd5e0;
}

.tag-control-btn:disabled {
  opacity: 0.5;
  cursor: not-allowed;
}

.tag-count {
  font-size: 0.75rem;
  color: #718096;
  font-weight: 500;
}

/* Tag search controls */
.tag-search-controls {
  display: flex;
  align-items: center;
  gap: 1rem;
  padding: 0.75rem 1rem;
  background: #f8fafc;
  border-bottom: 1px solid #e2e8f0;
}

.tag-search-input {
  flex: 1;
  padding: 0.375rem 0.5rem;
  border: 1px solid #e2e8f0;
  border-radius: 4px;
  font-size: 0.875rem;
  background: white;
  transition: border-color 0.2s;
}

.tag-search-input:focus {
  outline: none;
  border-color: #667eea;
  box-shadow: 0 0 0 1px #667eea;
}

.filtered-count {
  font-size: 0.75rem;
  color: #718096;
  font-weight: 500;
  white-space: nowrap;
}

.tags-display-area {
  padding: 1rem;
}

@media (max-width: 768px) {
  .filters-grid {
    grid-template-columns: 1fr;
    gap: 0.5rem;
    padding: 0.75rem;
  }

  .modality-filter-title {
    padding: 0.5rem 0.75rem;
    font-size: 0.9rem;
  }

  .tag-search-controls {
    flex-direction: column;
    align-items: stretch;
    gap: 0.5rem;
  }
}
@@ -2,9 +2,9 @@ import { useEffect } from 'react'
import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom'
import './App.css'

import { useAuth } from './hooks/useAuth'
import { useOAuth } from './hooks/useOAuth'
import { Loading, LoginPrompt, AuthError, Dashboard, Search } from './components'
import { useAuth } from '@/hooks/useAuth'
import { useOAuth } from '@/hooks/useOAuth'
import { Loading, LoginPrompt, AuthError, Dashboard, Search } from '@/components'

// AuthWrapper handles redirects based on auth state
const AuthWrapper = () => {
@@ -31,7 +31,10 @@ const AuthWrapper = () => {
  // Handle redirects based on auth state changes
  useEffect(() => {
    if (!isLoading) {
      if (isAuthenticated) {
      if (location.pathname === '/ui/logout') {
        logout()
        navigate('/ui/login', { replace: true })
      } else if (isAuthenticated) {
        // If authenticated and on login page, redirect to dashboard
        if (location.pathname === '/ui/login' || location.pathname === '/ui') {
          navigate('/ui/dashboard', { replace: true })
@@ -1,5 +1,5 @@
import { Link } from 'react-router-dom'
import { useMCP } from '../hooks/useMCP'
import { useMCP } from '@/hooks/useMCP'

const Dashboard = ({ onLogout }) => {
  const { listNotes } = useMCP()
@@ -1,6 +1,4 @@
import React from 'react'

const Loading = ({ message = "Loading..." }) => {
const Loading = ({ message = "Loading..." }: { message?: string }) => {
  return (
    <div className="loading">
      <h2>{message}</h2>
@@ -1,6 +1,4 @@
import React from 'react'

const AuthError = ({ error, onRetry }) => {
const AuthError = ({ error, onRetry }: { error: string, onRetry: () => void }) => {
  return (
    <div className="error">
      <h2>Authentication Error</h2>
@@ -1,6 +1,4 @@
import React from 'react'

const LoginPrompt = ({ onLogin }) => {
const LoginPrompt = ({ onLogin }: { onLogin: () => void }) => {
  return (
    <div className="login-prompt">
      <h1>Memory App</h1>
@@ -1,5 +1,5 @@
export { default as Loading } from './Loading'
export { default as Dashboard } from './Dashboard'
export { default as Search } from './Search'
export { default as Search } from './search'
export { default as LoginPrompt } from './auth/LoginPrompt'
export { default as AuthError } from './auth/AuthError'
frontend/src/components/search/DynamicFilters.tsx (new file, 155 lines)
@@ -0,0 +1,155 @@
import { FilterInput } from './FilterInput'
import { CollectionMetadata, SchemaArg } from '@/types/mcp'

// Pure helper functions for schema processing
const formatFieldLabel = (field: string): string =>
  field.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase())

const shouldSkipField = (fieldName: string): boolean =>
  ['tags'].includes(fieldName)

const isCommonField = (fieldName: string): boolean =>
  ['size', 'filename', 'content_type', 'mail_message_id', 'sent_at', 'created_at', 'source_id'].includes(fieldName)

const createMinMaxFields = (fieldName: string, fieldConfig: any): [string, any][] => [
  [`min_${fieldName}`, { ...fieldConfig, description: `Min ${fieldConfig.description}` }],
  [`max_${fieldName}`, { ...fieldConfig, description: `Max ${fieldConfig.description}` }]
]

const createSizeFields = (): [string, any][] => [
  ['min_size', { type: 'int', description: 'Minimum size in bytes' }],
  ['max_size', { type: 'int', description: 'Maximum size in bytes' }]
]

const extractSchemaFields = (schema: Record<string, SchemaArg>, includeCommon = true): [string, SchemaArg][] =>
  Object.entries(schema)
    .filter(([fieldName]) => !shouldSkipField(fieldName))
    .filter(([fieldName]) => includeCommon || !isCommonField(fieldName))
    .flatMap(([fieldName, fieldConfig]) =>
      ['sent_at', 'published', 'created_at'].includes(fieldName)
        ? createMinMaxFields(fieldName, fieldConfig)
        : [[fieldName, fieldConfig] as [string, SchemaArg]]
    )

const getCommonFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): [string, SchemaArg][] => {
  const commonFieldsMap = new Map<string, SchemaArg>()

  // Manually add created_at fields even if not in schema
  createMinMaxFields('created_at', { type: 'datetime', description: 'Creation date' }).forEach(([field, config]) => {
    commonFieldsMap.set(field, config)
  })

  selectedModalities.forEach(modality => {
    const schema = schemas[modality].schema
    if (!schema) return

    Object.entries(schema).forEach(([fieldName, fieldConfig]) => {
      if (isCommonField(fieldName)) {
        if (['sent_at', 'created_at'].includes(fieldName)) {
          createMinMaxFields(fieldName, fieldConfig).forEach(([field, config]) => {
            commonFieldsMap.set(field, config)
          })
        } else if (fieldName === 'size') {
          createSizeFields().forEach(([field, config]) => {
            commonFieldsMap.set(field, config)
          })
        } else {
          commonFieldsMap.set(fieldName, fieldConfig)
        }
      }
    })
  })

  return Array.from(commonFieldsMap.entries())
}

const getModalityFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): Record<string, [string, SchemaArg][]> => {
  return selectedModalities.reduce((acc, modality) => {
    const schema = schemas[modality].schema
    if (!schema) return acc

    const schemaFields = extractSchemaFields(schema, false) // Exclude common fields

    if (schemaFields.length > 0) {
      acc[modality] = schemaFields
    }

    return acc
  }, {} as Record<string, [string, SchemaArg][]>)
}

interface DynamicFiltersProps {
  schemas: Record<string, CollectionMetadata>
  selectedModalities: string[]
  filters: Record<string, SchemaArg>
  onFilterChange: (field: string, value: SchemaArg) => void
}

export const DynamicFilters = ({
  schemas,
  selectedModalities,
  filters,
  onFilterChange
}: DynamicFiltersProps) => {
  const commonFields = getCommonFields(schemas, selectedModalities)
  const modalityFields = getModalityFields(schemas, selectedModalities)

  if (commonFields.length === 0 && Object.keys(modalityFields).length === 0) {
    return null
  }

  return (
    <div className="search-option">
      <label>Filters:</label>
      <div className="modality-filters">
        {/* Common/Document Properties Section */}
        {commonFields.length > 0 && (
          <details className="modality-filter-group" open>
            <summary className="modality-filter-title">
              Document Properties
            </summary>
            <div className="filters-grid">
              {commonFields.map(([field, fieldConfig]: [string, SchemaArg]) => (
                <div key={field} className="filter-field">
                  <label className="filter-label">
                    {formatFieldLabel(field)}:
                  </label>
                  <FilterInput
                    field={field}
                    fieldConfig={fieldConfig}
                    value={filters[field]}
                    onChange={onFilterChange}
                  />
                </div>
              ))}
            </div>
          </details>
        )}

        {/* Modality-specific sections */}
        {Object.entries(modalityFields).map(([modality, fields]) => (
          <details key={modality} className="modality-filter-group">
            <summary className="modality-filter-title">
              {formatFieldLabel(modality)} Specific
            </summary>
            <div className="filters-grid">
              {fields.map(([field, fieldConfig]: [string, SchemaArg]) => (
                <div key={field} className="filter-field">
                  <label className="filter-label">
                    {formatFieldLabel(field)}:
                  </label>
                  <FilterInput
                    field={field}
                    fieldConfig={fieldConfig}
                    value={filters[field]}
                    onChange={onFilterChange}
                  />
                </div>
              ))}
            </div>
          </details>
        ))}
      </div>
    </div>
  )
}
frontend/src/components/search/FilterInput.tsx (new file, 53 lines)
@@ -0,0 +1,53 @@
// Pure helper functions
const isDateField = (field: string): boolean =>
  field.includes('sent_at') || field.includes('published') || field.includes('created_at')

const isNumberField = (field: string, fieldConfig: any): boolean =>
  fieldConfig.type?.includes('int') || field.includes('size')

const parseNumberValue = (value: string): number | null =>
  value ? parseInt(value) : null

interface FilterInputProps {
  field: string
  fieldConfig: any
  value: any
  onChange: (field: string, value: any) => void
}

export const FilterInput = ({ field, fieldConfig, value, onChange }: FilterInputProps) => {
  const inputProps = {
    value: value || '',
    className: "filter-input"
  }

  if (isNumberField(field, fieldConfig)) {
    return (
      <input
        {...inputProps}
        type="number"
        onChange={(e) => onChange(field, parseNumberValue(e.target.value))}
        placeholder={fieldConfig.description}
      />
    )
  }

  if (isDateField(field)) {
    return (
      <input
        {...inputProps}
        type="date"
        onChange={(e) => onChange(field, e.target.value || null)}
      />
    )
  }

  return (
    <input
      {...inputProps}
      type="text"
      onChange={(e) => onChange(field, e.target.value || null)}
      placeholder={fieldConfig.description}
    />
  )
}
frontend/src/components/search/Search.tsx (new file, 69 lines)
@@ -0,0 +1,69 @@
import { useState } from 'react'
import { useNavigate } from 'react-router-dom'
import { useMCP } from '@/hooks/useMCP'
import Loading from '@/components/Loading'
import { SearchResult } from './results'
import SearchForm, { SearchParams } from './SearchForm'

const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
  if (isLoading) {
    return <Loading message="Searching..." />
  }
  return (
    <div className="search-results">
      {results.length > 0 && (
        <div className="results-count">
          Found {results.length} result{results.length !== 1 ? 's' : ''}
        </div>
      )}

      {results.map((result, index) => <SearchResult key={index} result={result} />)}

      {results.length === 0 && (
        <div className="no-results">
          No results found
        </div>
      )}
    </div>
  )
}


const Search = () => {
  const navigate = useNavigate()
  const [results, setResults] = useState([])
  const [isLoading, setIsLoading] = useState(false)
  const { searchKnowledgeBase } = useMCP()

  const handleSearch = async (params: SearchParams) => {
    if (!params.query.trim()) return

    setIsLoading(true)
    try {
      console.log(params)
      const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
      setResults(searchResults || [])
    } catch (error) {
      console.error('Search error:', error)
      setResults([])
    } finally {
      setIsLoading(false)
    }
  }

  return (
    <div className="search-view">
      <header className="search-header">
        <button onClick={() => navigate('/ui/dashboard')} className="back-btn">
          ← Back to Dashboard
        </button>
        <h2>🔍 Search Knowledge Base</h2>
      </header>

      <SearchForm isLoading={isLoading} onSearch={handleSearch} />
      <SearchResults results={results} isLoading={isLoading} />
    </div>
  )
}

export default Search
frontend/src/components/search/SearchForm.tsx (new file, 151 lines)
@@ -0,0 +1,151 @@
import { useMCP } from '@/hooks/useMCP'
import { useEffect, useState } from 'react'
import { DynamicFilters } from './DynamicFilters'
import { SelectableTags } from './SelectableTags'
import { CollectionMetadata } from '@/types/mcp'

type Filter = {
  tags?: string[]
  source_ids?: string[]
  [key: string]: any
}

export interface SearchParams {
  query: string
  previews: boolean
  modalities: string[]
  filters: Filter
  limit: number
}

interface SearchFormProps {
  isLoading: boolean
  onSearch: (params: SearchParams) => void
}


// Pure helper functions for SearchForm
const createFlags = (items: string[], defaultValue = false): Record<string, boolean> =>
  items.reduce((acc, item) => ({ ...acc, [item]: defaultValue }), {})

const getSelectedItems = (items: Record<string, boolean>): string[] =>
  Object.entries(items).filter(([_, selected]) => selected).map(([key]) => key)

const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
  Object.entries(filters)
    .filter(([_, value]) => value !== null && value !== '' && value !== undefined)
    .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {})

export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
  const [query, setQuery] = useState('')
  const [previews, setPreviews] = useState(false)
  const [modalities, setModalities] = useState<Record<string, boolean>>({})
  const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
  const [tags, setTags] = useState<Record<string, boolean>>({})
  const [dynamicFilters, setDynamicFilters] = useState<Record<string, any>>({})
  const [limit, setLimit] = useState(10)
  const { getMetadataSchemas, getTags } = useMCP()

  useEffect(() => {
    const setupFilters = async () => {
      const [schemas, tags] = await Promise.all([
        getMetadataSchemas(),
        getTags()
      ])
      setSchemas(schemas)
      setModalities(createFlags(Object.keys(schemas), true))
      setTags(createFlags(tags))
    }
    setupFilters()
  }, [getMetadataSchemas, getTags])

  const handleFilterChange = (field: string, value: any) =>
    setDynamicFilters(prev => ({ ...prev, [field]: value }))

  const handleSubmit = (e: React.FormEvent) => {
    e.preventDefault()

    onSearch({
      query,
      previews,
      modalities: getSelectedItems(modalities),
      filters: {
        tags: getSelectedItems(tags),
        ...cleanFilters(dynamicFilters)
      },
      limit
    })
  }

  return (
    <form onSubmit={handleSubmit} className="search-form">
      <div className="search-input-group">
        <input
          type="text"
          value={query}
          onChange={(e) => setQuery(e.target.value)}
          placeholder="Search your knowledge base..."
          className="search-input"
          required
        />
        <button type="submit" disabled={isLoading} className="search-btn">
          {isLoading ? 'Searching...' : 'Search'}
        </button>
      </div>

      <div className="search-options">
        <div className="search-option">
          <label>
            <input
              type="checkbox"
              checked={previews}
              onChange={(e) => setPreviews(e.target.checked)}
            />
            Include content previews
          </label>
        </div>

        <SelectableTags
          title="Modalities"
          className="modality-checkboxes"
          tags={modalities}
          onSelect={(tag, selected) => setModalities({ ...modalities, [tag]: selected })}
          onBatchUpdate={(updates) => setModalities(updates)}
        />

        <SelectableTags
          title="Tags"
          className="tags-container"
          tags={tags}
          onSelect={(tag, selected) => setTags({ ...tags, [tag]: selected })}
          onBatchUpdate={(updates) => setTags(updates)}
          searchable={true}
        />

        <DynamicFilters
          schemas={schemas}
          selectedModalities={getSelectedItems(modalities)}
          filters={dynamicFilters}
          onFilterChange={handleFilterChange}
        />

        <div className="search-option">
          <label>
            Max Results:
            <input
              type="number"
              value={limit}
              onChange={(e) => setLimit(parseInt(e.target.value) || 10)}
              min={1}
              max={100}
              className="limit-input"
            />
          </label>
        </div>
      </div>
    </form>
  )
}

export default SearchForm
frontend/src/components/search/SelectableTags.tsx (new file, 136 lines)
@@ -0,0 +1,136 @@
import { useState } from 'react'

interface SelectableTagProps {
  tag: string
  selected: boolean
  onSelect: (tag: string, selected: boolean) => void
}

const SelectableTag = ({ tag, selected, onSelect }: SelectableTagProps) => {
  return (
    <span
      className={`tag ${selected ? 'selected' : ''}`}
      onClick={() => onSelect(tag, !selected)}
    >
      {tag}
    </span>
  )
}

interface SelectableTagsProps {
  title: string
  className: string
  tags: Record<string, boolean>
  onSelect: (tag: string, selected: boolean) => void
  onBatchUpdate?: (updates: Record<string, boolean>) => void
  searchable?: boolean
}

export const SelectableTags = ({ title, className, tags, onSelect, onBatchUpdate, searchable = false }: SelectableTagsProps) => {
  const [searchTerm, setSearchTerm] = useState('')
  const handleSelectAll = () => {
    if (onBatchUpdate) {
      const updates = Object.keys(tags).reduce((acc, tag) => {
        acc[tag] = true
        return acc
      }, {} as Record<string, boolean>)
      onBatchUpdate(updates)
    } else {
      // Fallback to individual updates (though this won't work well)
      Object.keys(tags).forEach(tag => {
        if (!tags[tag]) {
          onSelect(tag, true)
        }
      })
    }
  }

  const handleDeselectAll = () => {
    if (onBatchUpdate) {
      const updates = Object.keys(tags).reduce((acc, tag) => {
        acc[tag] = false
        return acc
      }, {} as Record<string, boolean>)
      onBatchUpdate(updates)
    } else {
      // Fallback to individual updates (though this won't work well)
      Object.keys(tags).forEach(tag => {
        if (tags[tag]) {
          onSelect(tag, false)
        }
      })
    }
  }

  // Filter tags based on search term
  const filteredTags = Object.entries(tags).filter(([tag, selected]) => {
    return !searchTerm || tag.toLowerCase().includes(searchTerm.toLowerCase())
  })

  const selectedCount = Object.values(tags).filter(Boolean).length
  const totalCount = Object.keys(tags).length
  const filteredSelectedCount = filteredTags.filter(([_, selected]) => selected).length
  const filteredTotalCount = filteredTags.length
  const allSelected = selectedCount === totalCount
  const noneSelected = selectedCount === 0

  return (
    <div className="search-option">
      <details className="selectable-tags-details">
        <summary className="selectable-tags-summary">
          {title} ({selectedCount} selected)
        </summary>

        <div className="tag-controls">
          <button
            type="button"
            className="tag-control-btn"
            onClick={handleSelectAll}
            disabled={allSelected}
          >
            All
          </button>
          <button
            type="button"
            className="tag-control-btn"
            onClick={handleDeselectAll}
            disabled={noneSelected}
          >
            None
          </button>
          <span className="tag-count">
            ({selectedCount}/{totalCount})
          </span>
        </div>

        {searchable && (
          <div className="tag-search-controls">
            <input
              type="text"
              placeholder={`Search ${title.toLowerCase()}...`}
              value={searchTerm}
              onChange={(e) => setSearchTerm(e.target.value)}
              className="tag-search-input"
            />
            {searchTerm && (
              <span className="filtered-count">
                Showing {filteredSelectedCount}/{filteredTotalCount}
              </span>
            )}
          </div>
        )}

        <div className={`${className} tags-display-area`}>
          {filteredTags.map(([tag, selected]: [string, boolean]) => (
            <SelectableTag
              key={tag}
              tag={tag}
              selected={selected}
              onSelect={onSelect}
            />
          ))}
        </div>
      </details>
    </div>
  )
}
frontend/src/components/search/index.js (new file, 3 lines)
@@ -0,0 +1,3 @@
import Search from './Search'

export default Search
@@ -1,10 +1,8 @@
import React, { useState, useEffect } from 'react'
import { useNavigate } from 'react-router-dom'
import { useState, useEffect } from 'react'
import ReactMarkdown from 'react-markdown'
import { useMCP } from '../hooks/useMCP'
import Loading from './Loading'
import { useMCP } from '@/hooks/useMCP'

type SearchItem = {
export type SearchItem = {
  filename: string
  content: string
  chunks: any[]
@@ -13,7 +11,7 @@ type SearchItem = {
  metadata: any
}

const Tag = ({ tags }: { tags: string[] }) => {
export const Tag = ({ tags }: { tags: string[] }) => {
  return (
    <div className="tags">
      {tags?.map((tag: string, index: number) => (
@@ -23,11 +21,12 @@ const Tag = ({ tags }: { tags: string[] }) => {
  )
}

const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
export const TextResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
  return (
    <div className="search-result-card">
      <h4>{filename || 'Untitled'}</h4>
      <Tag tags={tags} />
      <Metadata metadata={metadata} />
      <p className="result-content">{content || 'No content available'}</p>
      {chunks && chunks.length > 0 && (
        <details className="result-chunks">
@@ -44,7 +43,7 @@ const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
  )
}

const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
  return (
    <div className="search-result-card">
      <h4>{filename || 'Untitled'}</h4>
@@ -70,7 +69,7 @@ const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchIte
  )
}

const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
  const title = metadata?.title || filename || 'Untitled'
  const { fetchFile } = useMCP()
  const [mime_type, setMimeType] = useState<string>()
@@ -95,7 +94,7 @@ const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
  )
}

const Metadata = ({ metadata }: { metadata: any }) => {
export const Metadata = ({ metadata }: { metadata: any }) => {
  if (!metadata) return null
  return (
    <div className="metadata">
@@ -108,7 +107,7 @@ const Metadata = ({ metadata }: { metadata: any }) => {
  )
}

const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
export const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
  return (
    <div className="search-result-card">
      <h4>{filename || 'Untitled'}</h4>
@@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
  )
}

const EmailResult = ({ content, tags, metadata }: SearchItem) => {
export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
  return (
    <div className="search-result-card">
      <h4>{metadata?.title || metadata?.subject || 'Untitled'}</h4>
@@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => {
  )
}

const SearchResult = ({ result }: { result: SearchItem }) => {
export const SearchResult = ({ result }: { result: SearchItem }) => {
  if (result.mime_type.startsWith('image/')) {
    return <ImageResult {...result} />
  }
@@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => {
  return null
}

const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
  if (isLoading) {
    return <Loading message="Searching..." />
  }
  return (
    <div className="search-results">
      {results.length > 0 && (
        <div className="results-count">
          Found {results.length} result{results.length !== 1 ? 's' : ''}
        </div>
      )}

      {results.map((result, index) => <SearchResult key={index} result={result} />)}

      {results.length === 0 && (
        <div className="no-results">
          No results found
        </div>
      )}
    </div>
  )
}

const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => {
  const [query, setQuery] = useState('')
  return (
    <form onSubmit={(e) => {
      e.preventDefault()
      onSearch(query)
    }} className="search-form">
      <div className="search-input-group">
        <input
          type="text"
          value={query}
          onChange={(e) => setQuery(e.target.value)}
          placeholder="Search your knowledge base..."
          className="search-input"
        />
        <button type="submit" disabled={isLoading} className="search-btn">
          {isLoading ? 'Searching...' : 'Search'}
        </button>
      </div>
    </form>
  )
}

const Search = () => {
  const navigate = useNavigate()
  const [results, setResults] = useState([])
  const [isLoading, setIsLoading] = useState(false)
  const { searchKnowledgeBase } = useMCP()

  const handleSearch = async (query: string) => {
    if (!query.trim()) return

    setIsLoading(true)
    try {
      const searchResults = await searchKnowledgeBase(query)
      setResults(searchResults || [])
    } catch (error) {
      console.error('Search error:', error)
      setResults([])
    } finally {
      setIsLoading(false)
    }
  }

  return (
    <div className="search-view">
      <header className="search-header">
        <button onClick={() => navigate('/ui/dashboard')} className="back-btn">
          ← Back to Dashboard
        </button>
        <h2>🔍 Search Knowledge Base</h2>
      </header>

      <SearchForm isLoading={isLoading} onSearch={handleSearch} />
      <SearchResults results={results} isLoading={isLoading} />
    </div>
  )
}

export default Search
export default SearchResult
@@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000'
const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'

// Cookie utilities
const getCookie = (name) => {
const getCookie = (name: string) => {
  const value = `; ${document.cookie}`
  const parts = value.split(`; ${name}=`)
  if (parts.length === 2) return parts.pop().split(';').shift()
  return null
}

const setCookie = (name, value, days = 30) => {
const setCookie = (name: string, value: string, days = 30) => {
  const expires = new Date()
  expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
  document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
}

const deleteCookie = (name) => {
const deleteCookie = (name: string) => {
  document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
}

@@ -68,6 +68,7 @@ export const useAuth = () => {
    deleteCookie('access_token')
    deleteCookie('refresh_token')
    deleteCookie(SESSION_COOKIE_NAME)
    localStorage.removeItem('oauth_client_id')
    setIsAuthenticated(false)
  }, [])

@@ -110,7 +111,7 @@ export const useAuth = () => {
  }, [logout])

  // Make authenticated API calls with automatic token refresh
  const apiCall = useCallback(async (endpoint, options = {}) => {
  const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => {
    let accessToken = getCookie('access_token')

    if (!accessToken) {
@@ -122,7 +123,7 @@ export const useAuth = () => {
      'Content-Type': 'application/json',
    }

    const requestOptions = {
    const requestOptions: RequestInit & { headers: Record<string, string> } = {
      ...options,
      headers: { ...defaultHeaders, ...options.headers },
    }
@@ -1,5 +1,5 @@
import { useEffect, useCallback } from 'react'
import { useAuth } from './useAuth'
import { useAuth } from '@/hooks/useAuth'

const parseServerSentEvents = async (response: Response): Promise<any> => {
  const reader = response.body?.getReader()
@@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise<any> => {
}

export const useMCP = () => {
  const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth()
  const { apiCall, checkAuth } = useAuth()

  const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => {
    const response = await apiCall(`/mcp${path}`, {
  const mcpCall = useCallback(async (method: string, params: any = {}) => {
    const response = await apiCall(`/mcp/${method}`, {
      method: 'POST',
      headers: {
        'Accept': 'application/json, text/event-stream',
@@ -118,22 +118,46 @@ export const useMCP = () => {
    if (resp?.result?.isError) {
      throw new Error(resp?.result?.content[0].text)
    }
    return resp?.result?.content.map((item: any) => JSON.parse(item.text))
    return resp?.result?.content.map((item: any) => {
      try {
        return JSON.parse(item.text)
      } catch (e) {
        return item.text
      }
    })
  }, [apiCall])

  const listNotes = useCallback(async (path: string = "/") => {
    return await mcpCall('/note_files', 'note_files', { path })
    return await mcpCall('note_files', { path })
  }, [mcpCall])

  const fetchFile = useCallback(async (filename: string) => {
    return await mcpCall('/fetch_file', 'fetch_file', { filename })
    return await mcpCall('fetch_file', { filename })
  }, [mcpCall])

  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => {
    return await mcpCall('/search_knowledge_base', 'search_knowledge_base', {
  const getTags = useCallback(async () => {
    return await mcpCall('get_all_tags')
  }, [mcpCall])

  const getSubjects = useCallback(async () => {
    return await mcpCall('get_all_subjects')
  }, [mcpCall])

  const getObservationTypes = useCallback(async () => {
    return await mcpCall('get_all_observation_types')
  }, [mcpCall])

  const getMetadataSchemas = useCallback(async () => {
    return (await mcpCall('get_metadata_schemas'))[0]
  }, [mcpCall])

  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
    return await mcpCall('search_knowledge_base', {
      query,
      filters,
      modalities,
      previews,
      limit
      limit,
    })
  }, [mcpCall])

@@ -146,5 +170,9 @@ export const useMCP = () => {
    fetchFile,
    listNotes,
    searchKnowledgeBase,
    getTags,
    getSubjects,
    getObservationTypes,
    getMetadataSchemas,
  }
}
@@ -14,7 +14,7 @@ const generateCodeVerifier = () => {
    .replace(/=/g, '')
}

const generateCodeChallenge = async (verifier) => {
const generateCodeChallenge = async (verifier: string) => {
  const data = new TextEncoder().encode(verifier)
  const digest = await crypto.subtle.digest('SHA-256', data)
  return btoa(String.fromCharCode(...new Uint8Array(digest)))
@@ -33,7 +33,7 @@ const generateState = () => {
}

// Storage utilities
const setCookie = (name, value, days = 30) => {
const setCookie = (name: string, value: string, days = 30) => {
  const expires = new Date()
  expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
  document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
@@ -1,6 +1,6 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import App from './App.jsx'
import App from '@/App.jsx'

createRoot(document.getElementById('root')).render(
  <StrictMode>
frontend/src/types/mcp.tsx (new file, 9 lines)
@@ -0,0 +1,9 @@
export type CollectionMetadata = {
  schema: Record<string, SchemaArg>
  size: number
}

export type SchemaArg = {
  type: string
  description: string
}
frontend/tsconfig.json (new file, 40 lines)
@@ -0,0 +1,40 @@
{
  "compilerOptions": {
    "target": "ES2020",
    "useDefineForClassFields": true,
    "lib": [
      "ES2020",
      "DOM",
      "DOM.Iterable"
    ],
    "module": "ESNext",
    "skipLibCheck": true,
    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "noEmit": true,
    "jsx": "react-jsx",
    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "noFallthroughCasesInSwitch": true,
    /* Absolute imports */
    "baseUrl": ".",
    "paths": {
      "@/*": [
        "src/*"
      ]
    }
  },
  "include": [
    "src"
  ],
  "references": [
    {
      "path": "./tsconfig.node.json"
    }
  ]
}
frontend/tsconfig.node.json (new file, 12 lines)
@@ -0,0 +1,12 @@
{
  "compilerOptions": {
    "composite": true,
    "skipLibCheck": true,
    "module": "ESNext",
    "moduleResolution": "bundler",
    "allowSyntheticDefaultImports": true
  },
  "include": [
    "vite.config.js"
  ]
}
@@ -1,8 +1,14 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
import path from 'path'

// https://vite.dev/config/
export default defineConfig({
  plugins: [react()],
  base: '/ui/',
  resolve: {
    alias: {
      '@': path.resolve(__dirname, './src'),
    },
  },
})
@@ -1,2 +1,3 @@
import memory.api.MCP.manifest
import memory.api.MCP.memory
import memory.api.MCP.metadata
@@ -127,7 +127,6 @@ async def handle_login(request: Request):
        return login_form(request, oauth_params, "Invalid email or password")

    redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
    print("redirect_url", redirect_url)
    if redirect_url.startswith("http://anysphere.cursor-retrieval"):
        redirect_url = redirect_url.replace("http://", "cursor://")
    return RedirectResponse(url=redirect_url, status_code=302)
@@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system.
"""

import logging
import mimetypes
import pathlib
from datetime import datetime, timezone
import mimetypes
from typing import Any

from pydantic import BaseModel
from sqlalchemy import Text, func
from sqlalchemy import Text
from sqlalchemy import cast as sql_cast
from sqlalchemy.dialects.postgresql import ARRAY

from memory.api.MCP.tools import mcp
from memory.api.search.search import SearchFilters, search
from memory.common import extract, settings
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
from memory.common.celery_app import app as celery_app
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
from memory.common.db.connection import make_session
from memory.common.db.models import AgentObservation, SourceItem
from memory.common.db.models import SourceItem, AgentObservation
from memory.common.formatters import observation
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
from memory.api.MCP.tools import mcp

logger = logging.getLogger(__name__)

@@ -47,11 +49,13 @@ def filter_observation_source_ids(
    return source_ids


def filter_source_ids(
    modalities: set[str],
    tags: list[str] | None = None,
):
    if not tags:
def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None:
    if source_ids := filters.get("source_ids"):
        return source_ids

    tags = filters.get("tags")
    size = filters.get("size")
    if not (tags or size):
        return None

    with make_session() as session:
@@ -62,6 +66,8 @@ def filter_source_ids(
        items_query = items_query.filter(
            SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
        )
        if size:
            items_query = items_query.filter(SourceItem.size == size)
        if modalities:
            items_query = items_query.filter(SourceItem.modality.in_(modalities))
        source_ids = [item.id for item in items_query.all()]
@@ -69,51 +75,12 @@ def filter_source_ids(
    return source_ids
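For illustration only, a minimal sketch (not part of the commit) of how the reworked filter_source_ids narrows a search, using the SearchFilters mapping imported above; the tag value is invented:

# Hypothetical call: restrict a later vector search to blog items that
# carry at least one matching tag. The function returns None when neither
# tags nor size filters are present, and passes explicit source_ids
# through unchanged.
ids = filter_source_ids({"blog"}, SearchFilters(tags=["ai"]))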
@mcp.tool()
async def get_all_tags() -> list[str]:
    """
    Get all unique tags used across the entire knowledge base.
    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})


@mcp.tool()
async def get_all_subjects() -> list[str]:
    """
    Get all unique subjects from observations about the user.
    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )


@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """
    Get all observation types that have been used.
    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )


@mcp.tool()
async def search_knowledge_base(
    query: str,
    previews: bool = False,
    filters: dict[str, Any],
    modalities: set[str] = set(),
    tags: list[str] = [],
    previews: bool = False,
    limit: int = 10,
) -> list[dict]:
    """
@@ -125,7 +92,7 @@ async def search_knowledge_base(
        query: Natural language search query - be descriptive about what you're looking for
        previews: Include actual content in results - when false only a snippet is returned
        modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
        tags: Filter by tags - content must have at least one matching tag
        filters: Filter by tags, source_ids, etc.
        limit: Max results (1-100)

    Returns: List of search results with id, score, chunks, content, filename
@@ -137,6 +104,9 @@ async def search_knowledge_base(
        modalities = set(ALL_COLLECTIONS.keys())
    modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS

    search_filters = SearchFilters(**filters)
    search_filters["source_ids"] = filter_source_ids(modalities, search_filters)

    upload_data = extract.extract_text(query)
    results = await search(
        upload_data,
@@ -145,10 +115,7 @@ async def search_knowledge_base(
        limit=limit,
        min_text_score=0.4,
        min_multimodal_score=0.25,
        filters=SearchFilters(
            tags=tags,
            source_ids=filter_source_ids(tags=tags, modalities=modalities),
        ),
        filters=search_filters,
    )

    return [result.model_dump() for result in results]
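For illustration, a hedged sketch of invoking the reworked tool with its new filters argument; the field names follow the range-filter conventions introduced in this commit, but every value here is invented:

# Hypothetical invocation of the MCP tool with the new-style filters dict.
results = await search_knowledge_base(
    query="notes on retrieval evaluation",
    previews=True,
    filters={"tags": ["ml"], "min_created_at": "2025-01-01"},
    modalities={"blog", "book"},
    limit=10,
)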
src/memory/api/MCP/metadata.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import logging
from collections import defaultdict
from typing import Annotated, TypedDict, get_args, get_type_hints

from memory.common import qdrant
from sqlalchemy import func

from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation

logger = logging.getLogger(__name__)


class SchemaArg(TypedDict):
    type: str | None
    description: str | None


class CollectionMetadata(TypedDict):
    schema: dict[str, SchemaArg]
    size: int


def from_annotation(annotation: Annotated) -> SchemaArg | None:
    try:
        type_, description = get_args(annotation)
        type_str = str(type_)
        if type_str.startswith("typing."):
            type_str = type_str[7:]
        elif len((parts := type_str.split("'"))) > 1:
            type_str = parts[1]
        return SchemaArg(type=type_str, description=description)
    except IndexError:
        logger.error(f"Error from annotation: {annotation}")
        return None


def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
    if not hasattr(klass, "as_payload"):
        return {}

    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
        return {}

    return {
        name: schema
        for name, arg in payload_type.__annotations__.items()
        if (schema := from_annotation(arg))
    }
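A small sketch (not in the commit) of what from_annotation yields for a typical Annotated payload field; the field content is invented:

from typing import Annotated

field = Annotated[str, "The subject of the email."]
# get_args() splits the annotation into (str, "The subject of the email."),
# and the "<class 'str'>" repr is trimmed down to a bare "str".
assert from_annotation(field) == {
    "type": "str",
    "description": "The subject of the email.",
}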
@mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
    """Get the metadata schema for each collection used in the knowledge base.

    These schemas can be used to filter the knowledge base.

    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.

    Example:
    ```
    {
        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
    }
    ```
    """
    client = qdrant.get_qdrant_client()
    sizes = qdrant.get_collection_sizes(client)
    schemas = defaultdict(dict)
    for klass in SourceItem.__subclasses__():
        for collection in klass.get_collections():
            schemas[collection].update(get_schema(klass))

    return {
        collection: CollectionMetadata(schema=schema, size=size)
        for collection, schema in schemas.items()
        if (size := sizes.get(collection))
    }
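Note that the docstring example shows only the schema part; the returned CollectionMetadata also carries the collection size, so the full shape is roughly as follows (values invented):

# {
#     "mail": {
#         "schema": {"subject": {"type": "str", "description": "The subject of the email."}},
#         "size": 1234,
#     },
# }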
@mcp.tool()
async def get_all_tags() -> list[str]:
    """Get all unique tags used across the entire knowledge base.

    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})


@mcp.tool()
async def get_all_subjects() -> list[str]:
    """Get all unique subjects from observations about the user.

    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )


@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """Get all observation types that have been used.

    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )
@@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
# SQLAdmin setup with OAuth protection
engine = get_engine()
admin = Admin(app, engine)
admin.app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # [settings.SERVER_URL]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Setup admin with OAuth protection using existing OAuth provider
setup_admin(admin)
@@ -1,7 +1,7 @@
import base64
import io
import logging
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, cast

import qdrant_client
from PIL import Image
@@ -90,13 +90,71 @@ def query_chunks(
    }


def merge_range_filter(
    filters: list[dict[str, Any]], key: str, val: Any
) -> list[dict[str, Any]]:
    direction, field = key.split("_", maxsplit=1)
    item = next((f for f in filters if f["key"] == field), None)
    if not item:
        item = {"key": field, "range": {}}
        filters.append(item)

    if direction == "min":
        item["range"]["gte"] = val
    elif direction == "max":
        item["range"]["lte"] = val
    return filters


def merge_filters(
    filters: list[dict[str, Any]], key: str, val: Any
) -> list[dict[str, Any]]:
    if not val and val != 0:
        return filters

    list_filters = ["tags", "recipients", "observation_types", "authors"]
    range_filters = [
        "min_sent_at",
        "max_sent_at",
        "min_published",
        "max_published",
        "min_size",
        "max_size",
        "min_created_at",
        "max_created_at",
    ]
    if key in list_filters:
        filters.append({"key": key, "match": {"any": val}})

    elif key in range_filters:
        return merge_range_filter(filters, key, val)

    elif key == "min_confidences":
        confidence_filters = [
            {
                "key": f"confidence.{confidence_type}",
                "range": {"gte": min_confidence_score},
            }
            for confidence_type, min_confidence_score in cast(dict, val).items()
        ]
        filters.extend(confidence_filters)

    elif key == "source_ids":
        filters.append({"key": "id", "match": {"any": val}})

    else:
        filters.append({"key": key, "match": val})

    return filters
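A worked sketch (not part of the commit) of how merge_filters folds flat filter keys into Qdrant-style clauses, assuming the implementation above; note that min/max pairs collapse into a single range clause per field:

clauses: list[dict[str, Any]] = []
for key, val in {"tags": ["ml"], "min_size": 1024, "max_size": 4096}.items():
    clauses = merge_filters(clauses, key, val)
# clauses == [
#     {"key": "tags", "match": {"any": ["ml"]}},
#     {"key": "size", "range": {"gte": 1024, "lte": 4096}},
# ]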
async def search_embeddings(
|
||||
data: list[extract.DataChunk],
|
||||
previews: Optional[bool] = False,
|
||||
modalities: set[str] = set(),
|
||||
limit: int = 10,
|
||||
min_score: float = 0.3,
|
||||
filters: SearchFilters = SearchFilters(),
|
||||
filters: SearchFilters = {},
|
||||
multimodal: bool = False,
|
||||
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
||||
"""

@@ -111,27 +169,11 @@ async def search_embeddings(
     - filters: Filters to apply to the search results
     - multimodal: Whether to search in multimodal collections
     """
-    query_filters = {"must": []}
-
-    # Handle structured confidence filtering
-    if min_confidences := filters.get("min_confidences"):
-        confidence_filters = [
-            {
-                "key": f"confidence.{confidence_type}",
-                "range": {"gte": min_confidence_score},
-            }
-            for confidence_type, min_confidence_score in min_confidences.items()
-        ]
-        query_filters["must"].extend(confidence_filters)
-
-    if tags := filters.get("tags"):
-        query_filters["must"].append({"key": "tags", "match": {"any": tags}})
-
-    if observation_types := filters.get("observation_types"):
-        query_filters["must"].append(
-            {"key": "observation_type", "match": {"any": observation_types}}
-        )
-
+    search_filters = []
+    for key, val in filters.items():
+        search_filters = merge_filters(search_filters, key, val)
+
+    print(search_filters)
     client = qdrant.get_qdrant_client()
     results = query_chunks(
         client,

@@ -140,7 +182,7 @@ async def search_embeddings(
         embedding.embed_text if not multimodal else embedding.embed_mixed,
         min_score=min_score,
         limit=limit,
-        filters=query_filters if query_filters["must"] else None,
+        filters={"must": search_filters} if search_filters else None,
     )
     search_results = {k: results.get(k, []) for k in modalities}
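
After this change, a non-empty merged condition list reaches Qdrant wrapped
in a single must clause (None otherwise). Sketch of what query_chunks
receives for tags=["work"] and min_size=100:

    {"must": [
        {"key": "tags", "match": {"any": ["work"]}},
        {"key": "size", "range": {"gte": 100}},
    ]}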

@@ -17,6 +17,7 @@ from memory.common.collections import (
     MULTIMODAL_COLLECTIONS,
     TEXT_COLLECTIONS,
 )
+from memory.common import settings

 logger = logging.getLogger(__name__)

@@ -44,51 +45,57 @@ async def search(
     - List of search results sorted by score
     """
     allowed_modalities = modalities & ALL_COLLECTIONS.keys()
     print(allowed_modalities)

-    text_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & TEXT_COLLECTIONS,
-            limit,
-            min_text_score,
-            filters,
-            multimodal=False,
-        ),
-        timeout,
-    )
-    multimodal_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & MULTIMODAL_COLLECTIONS,
-            limit,
-            min_multimodal_score,
-            filters,
-            multimodal=True,
-        ),
-        timeout,
-    )
-    bm25_results = with_timeout(
-        search_bm25(
-            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
-            modalities,
-            limit=limit,
-            filters=filters,
-        ),
-        timeout,
-    )
-    results = await asyncio.gather(
-        text_embeddings_results,
-        multimodal_embeddings_results,
-        bm25_results,
-        return_exceptions=False,
-    )
-    text_results, multi_results, bm25_results = results
-    all_results = text_results + multi_results
-    if len(all_results) < limit:
-        all_results += bm25_results
+    searches = []
+    if settings.ENABLE_EMBEDDING_SEARCH:
+        searches = [
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & TEXT_COLLECTIONS,
+                    limit,
+                    min_text_score,
+                    filters,
+                    multimodal=False,
+                ),
+                timeout,
+            ),
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & MULTIMODAL_COLLECTIONS,
+                    limit,
+                    min_multimodal_score,
+                    filters,
+                    multimodal=True,
+                ),
+                timeout,
+            ),
+        ]
+    if settings.ENABLE_BM25_SEARCH:
+        searches.append(
+            with_timeout(
+                search_bm25(
+                    " ".join(
+                        [c for chunk in data for c in chunk.data if isinstance(c, str)]
+                    ),
+                    modalities,
+                    limit=limit,
+                    filters=filters,
+                ),
+                timeout,
+            )
+        )
+
+    search_results = await asyncio.gather(*searches, return_exceptions=False)
+    all_results = []
+    for results in search_results:
+        if len(all_results) >= limit:
+            break
+        all_results.extend(results)

     results = group_chunks(all_results, previews or False)
     return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
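
A note on the merging behaviour (illustrative, not in the diff):
asyncio.gather preserves task order, so embedding batches are drained before
BM25 ones, and a later batch is only appended while fewer than `limit`
results have accumulated. A self-contained sketch of the same logic:

    import asyncio

    async def first_n_from_ordered_sources(sources, limit):
        # batches come back in the order the coroutines were listed
        batches = await asyncio.gather(*sources)
        merged = []
        for batch in batches:
            if len(merged) >= limit:
                break
            merged.extend(batch)
        return merged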

@@ -65,9 +65,9 @@ class SearchResult(BaseModel):


 class SearchFilters(TypedDict):
     subject: NotRequired[str | None]
+    min_size: NotRequired[int]
+    max_size: NotRequired[int]
     min_confidences: NotRequired[dict[str, float]]
     tags: NotRequired[list[str] | None]
     observation_types: NotRequired[list[str] | None]
     source_ids: NotRequired[list[int] | None]
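
A hypothetical call site for the widened filter type (sketch; every key is
optional because the fields are NotRequired):

    filters: SearchFilters = {
        "tags": ["ai", "notes"],
        "min_size": 1_000,
        "max_size": 100_000,
        "min_confidences": {"observation_accuracy": 0.8},
    }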

@@ -115,7 +115,6 @@ def group_chunks(
     if isinstance(contents, dict):
         tags = contents.pop("tags", [])
         content = contents.pop("content", None)
-        print(content)
     else:
         content = contents
         contents = {}

@@ -4,6 +4,7 @@ from memory.common.db.models.source_item import (
     SourceItem,
     ConfidenceScore,
     clean_filename,
+    SourceItemPayload,
 )
 from memory.common.db.models.source_items import (
     MailMessage,

@@ -19,6 +20,14 @@ from memory.common.db.models.source_items import (
     Photo,
     MiscDoc,
     Note,
+    MailMessagePayload,
+    EmailAttachmentPayload,
+    AgentObservationPayload,
+    BlogPostPayload,
+    ComicPayload,
+    BookSectionPayload,
+    NotePayload,
+    ForumPostPayload,
 )
 from memory.common.db.models.observations import (
     ObservationContradiction,

@@ -40,6 +49,18 @@ from memory.common.db.models.users import (
     OAuthRefreshToken,
 )

+Payload = (
+    SourceItemPayload
+    | AgentObservationPayload
+    | NotePayload
+    | BlogPostPayload
+    | ComicPayload
+    | BookSectionPayload
+    | ForumPostPayload
+    | EmailAttachmentPayload
+    | MailMessagePayload
+)
+
 __all__ = [
     "Base",
     "Chunk",

@@ -75,4 +96,6 @@ __all__ = [
     "OAuthClientInformation",
     "OAuthState",
     "OAuthRefreshToken",
+    # Payloads
+    "Payload",
 ]
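
The Payload union is a typing aid for code that consumes any as_payload()
result. A sketch of a consumer annotation (hypothetical function name):

    from memory.common.db.models import Payload

    def index_payload(payload: Payload) -> None:
        # every variant extends SourceItemPayload, so these keys always exist
        source_id = payload["source_id"]
        tags = payload["tags"]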

@@ -4,7 +4,7 @@ Database models for the knowledge base system.

 import pathlib
 import re
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, TypedDict, cast
 import uuid

 from PIL import Image

@@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer
 from memory.common.db.models.base import Base


+class MetadataSchema(TypedDict):
+    type: str
+    description: str
+
+
+class SourceItemPayload(TypedDict):
+    source_id: Annotated[int, "Unique identifier of the source item"]
+    tags: Annotated[list[str], "List of tags for categorization"]
+    size: Annotated[int | None, "Size of the content in bytes"]
+
+
 @event.listens_for(Session, "before_flush")
 def handle_duplicate_sha256(session, flush_context, instances):
     """

@@ -344,12 +355,17 @@ class SourceItem(Base):
     def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
         return [self._make_chunk(data, metadata) for data in self._chunk_contents()]

-    def as_payload(self) -> dict:
-        return {
-            "source_id": self.id,
-            "tags": self.tags,
-            "size": self.size,
-        }
+    def as_payload(self) -> SourceItemPayload:
+        return SourceItemPayload(
+            source_id=cast(int, self.id),
+            tags=cast(list[str], self.tags),
+            size=cast(int | None, self.size),
+        )
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """Return the list of Qdrant collections this SourceItem type can be stored in."""
+        return [cls.__tablename__]

     @property
     def display_contents(self) -> str | dict | None:

@@ -5,7 +5,7 @@ Database models for the knowledge base system.
 import pathlib
 import textwrap
 from datetime import datetime
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, cast

 from PIL import Image
 from sqlalchemy import (

@@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation
 from memory.common.db.models.source_item import (
     SourceItem,
     Chunk,
+    SourceItemPayload,
     clean_filename,
     chunk_mixed,
 )


+class MailMessagePayload(SourceItemPayload):
+    message_id: Annotated[str, "Unique email message identifier"]
+    subject: Annotated[str, "Email subject line"]
+    sender: Annotated[str, "Email sender address"]
+    recipients: Annotated[list[str], "List of recipient email addresses"]
+    folder: Annotated[str, "Email folder name"]
+    date: Annotated[str | None, "Email sent date in ISO format"]
+
+
 class MailMessage(SourceItem):
     __tablename__ = "mail_message"

@@ -80,17 +90,21 @@ class MailMessage(SourceItem):
         path.parent.mkdir(parents=True, exist_ok=True)
         return path

-    def as_payload(self) -> dict:
-        return {
-            **super().as_payload(),
-            "message_id": self.message_id,
-            "subject": self.subject,
-            "sender": self.sender,
-            "recipients": self.recipients,
-            "folder": self.folder,
-            "tags": self.tags + [self.sender] + self.recipients,
-            "date": (self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
-        }
+    def as_payload(self) -> MailMessagePayload:
+        base_payload = super().as_payload() | {
+            "tags": cast(list[str], self.tags)
+            + [cast(str, self.sender)]
+            + cast(list[str], self.recipients)
+        }
+        return MailMessagePayload(
+            **cast(dict, base_payload),
+            message_id=cast(str, self.message_id),
+            subject=cast(str, self.subject),
+            sender=cast(str, self.sender),
+            recipients=cast(list[str], self.recipients),
+            folder=cast(str, self.folder),
+            date=(self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
+        )

     @property
     def parsed_content(self) -> dict[str, Any]:
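
One detail worth calling out (illustrative): the dict-union here replaces
the base "tags" value rather than appending to it, because the right-hand
operand of | wins on key conflicts:

    base = {"source_id": 1, "tags": ["mail"], "size": None}
    merged = base | {"tags": ["mail", "alice@example.com"]}
    # merged["tags"] == ["mail", "alice@example.com"]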

@@ -152,7 +166,7 @@ class MailMessage(SourceItem):

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = self.parsed_content
-        chunks = extract.extract_text(cast(str, self.body))
+        chunks = extract.extract_text(cast(str, self.body), modality="email")

         def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
             if isinstance(item, str):

@@ -163,6 +177,10 @@ class MailMessage(SourceItem):
             chunk.data = [add_header(item) for item in chunk.data]
         return chunks

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["mail"]
+
     # Add indexes
     __table_args__ = (
         Index("mail_sent_idx", "sent_at"),

@@ -171,6 +189,13 @@ class MailMessage(SourceItem):
 )


+class EmailAttachmentPayload(SourceItemPayload):
+    filename: Annotated[str, "Name of the document file"]
+    content_type: Annotated[str, "MIME type of the document"]
+    mail_message_id: Annotated[int, "Associated email message ID"]
+    sent_at: Annotated[str | None, "Document creation timestamp"]
+
+
 class EmailAttachment(SourceItem):
     __tablename__ = "email_attachment"

@@ -190,17 +215,20 @@ class EmailAttachment(SourceItem):
         "polymorphic_identity": "email_attachment",
     }

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> EmailAttachmentPayload:
+        return EmailAttachmentPayload(
             **super().as_payload(),
-            "filename": self.filename,
-            "content_type": self.mime_type,
-            "size": self.size,
-            "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
-            "mail_message_id": self.mail_message_id,
-        }
+            filename=cast(str, self.filename),
+            content_type=cast(str, self.mime_type),
+            mail_message_id=cast(int, self.mail_message_id),
+            sent_at=(
+                self.mail_message.sent_at
+                and self.mail_message.sent_at.isoformat()
+                or None
+            ),  # type: ignore
+        )

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         if cast(str | None, self.filename):
             contents = (
                 settings.FILE_STORAGE_DIR / cast(str, self.filename)

@@ -208,8 +236,7 @@ class EmailAttachment(SourceItem):
         else:
             contents = cast(str, self.content)

-        chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
-        return [self._make_chunk(c, metadata) for c in chunks]
+        return extract.extract_data_chunks(cast(str, self.mime_type), contents)

     @property
     def display_contents(self) -> dict:

@@ -221,6 +248,11 @@ class EmailAttachment(SourceItem):
     # Add indexes
     __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """EmailAttachment can go to different collections based on mime_type"""
+        return ["doc", "text", "blog", "photo", "book"]
+

 class ChatMessage(SourceItem):
     __tablename__ = "chat_message"

@@ -285,6 +317,16 @@ class Photo(SourceItem):
     __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)


+class ComicPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the comic"]
+    author: Annotated[str | None, "Author of the comic"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    volume: Annotated[str | None, "Volume number"]
+    issue: Annotated[str | None, "Issue number"]
+    page: Annotated[int | None, "Page number"]
+    url: Annotated[str | None, "URL of the comic"]
+
+
 class Comic(SourceItem):
     __tablename__ = "comic"

@@ -305,18 +347,17 @@ class Comic(SourceItem):

     __table_args__ = (Index("comic_author_idx", "author"),)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> ComicPayload:
+        return ComicPayload(
             **super().as_payload(),
-            "title": self.title,
-            "author": self.author,
-            "published": self.published,
-            "volume": self.volume,
-            "issue": self.issue,
-            "page": self.page,
-            "url": self.url,
-        }
-        return {k: v for k, v in payload.items() if v is not None}
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(self.published and self.published.isoformat() or None),  # type: ignore
+            volume=cast(str | None, self.volume),
+            issue=cast(str | None, self.issue),
+            page=cast(int | None, self.page),
+            url=cast(str | None, self.url),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))

@@ -324,6 +365,17 @@ class Comic(SourceItem):
         return [extract.DataChunk(data=[image, description])]


+class BookSectionPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the book"]
+    author: Annotated[str | None, "Author of the book"]
+    book_id: Annotated[int, "Unique identifier of the book"]
+    section_title: Annotated[str, "Title of the section"]
+    section_number: Annotated[int, "Number of the section"]
+    section_level: Annotated[int, "Level of the section"]
+    start_page: Annotated[int, "Starting page number"]
+    end_page: Annotated[int, "Ending page number"]
+
+
 class BookSection(SourceItem):
     """Individual sections/chapters of books"""

@@ -361,19 +413,22 @@ class BookSection(SourceItem):
         Index("book_section_level_idx", "section_level", "section_number"),
     )

-    def as_payload(self) -> dict:
-        vals = {
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["book"]
+
+    def as_payload(self) -> BookSectionPayload:
+        return BookSectionPayload(
             **super().as_payload(),
-            "title": self.book.title,
-            "author": self.book.author,
-            "book_id": self.book_id,
-            "section_title": self.section_title,
-            "section_number": self.section_number,
-            "section_level": self.section_level,
-            "start_page": self.start_page,
-            "end_page": self.end_page,
-        }
-        return {k: v for k, v in vals.items() if v}
+            title=cast(str, self.book.title),
+            author=cast(str | None, self.book.author),
+            book_id=cast(int, self.book_id),
+            section_title=cast(str, self.section_title),
+            section_number=cast(int, self.section_number),
+            section_level=cast(int, self.section_level),
+            start_page=cast(int, self.start_page),
+            end_page=cast(int, self.end_page),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         content = cast(str, self.content.strip())

@@ -397,6 +452,16 @@ class BookSection(SourceItem):
         ]


+class BlogPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the blog post"]
+    title: Annotated[str, "Title of the blog post"]
+    author: Annotated[str | None, "Author of the blog post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    description: Annotated[str | None, "Description of the blog post"]
+    domain: Annotated[str | None, "Domain of the blog post"]
+    word_count: Annotated[int | None, "Word count of the blog post"]
+
+
 class BlogPost(SourceItem):
     __tablename__ = "blog_post"

@@ -428,27 +493,39 @@ class BlogPost(SourceItem):
         Index("blog_post_word_count_idx", "word_count"),
     )

-    def as_payload(self) -> dict:
+    def as_payload(self) -> BlogPostPayload:
         published_date = cast(datetime | None, self.published)
         metadata = cast(dict | None, self.webpage_metadata) or {}

-        payload = {
+        return BlogPostPayload(
             **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "author": self.author,
-            "published": published_date and published_date.isoformat(),
-            "description": self.description,
-            "domain": self.domain,
-            "word_count": self.word_count,
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(published_date and published_date.isoformat() or None),  # type: ignore
+            description=cast(str | None, self.description),
+            domain=cast(str | None, self.domain),
+            word_count=cast(int | None, self.word_count),
             **metadata,
-        }
-        return {k: v for k, v in payload.items() if v}
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))


+class ForumPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the forum post"]
+    title: Annotated[str, "Title of the forum post"]
+    description: Annotated[str | None, "Description of the forum post"]
+    authors: Annotated[list[str] | None, "Authors of the forum post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    slug: Annotated[str | None, "Slug of the forum post"]
+    karma: Annotated[int | None, "Karma score of the forum post"]
+    votes: Annotated[int | None, "Number of votes on the forum post"]
+    score: Annotated[int | None, "Score of the forum post"]
+    comments: Annotated[int | None, "Number of comments on the forum post"]
+
+
 class ForumPost(SourceItem):
     __tablename__ = "forum_post"

@@ -479,20 +556,20 @@ class ForumPost(SourceItem):
         Index("forum_post_title_idx", "title"),
     )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> ForumPostPayload:
+        return ForumPostPayload(
             **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "description": self.description,
-            "authors": self.authors,
-            "published_at": self.published_at,
-            "slug": self.slug,
-            "karma": self.karma,
-            "votes": self.votes,
-            "score": self.score,
-            "comments": self.comments,
-        }
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            description=cast(str | None, self.description),
+            authors=cast(list[str] | None, self.authors),
+            published=(self.published_at and self.published_at.isoformat() or None),  # type: ignore
+            slug=cast(str | None, self.slug),
+            karma=cast(int | None, self.karma),
+            votes=cast(int | None, self.votes),
+            score=cast(int | None, self.score),
+            comments=cast(int | None, self.comments),
+        )

     def _chunk_contents(self) -> Sequence[extract.DataChunk]:
         return chunk_mixed(cast(str, self.content), cast(list[str], self.images))

@@ -545,6 +622,12 @@ class GithubItem(SourceItem):
     )


+class NotePayload(SourceItemPayload):
+    note_type: Annotated[str | None, "Category of the note"]
+    subject: Annotated[str | None, "What the note is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the note"]
+
+
 class Note(SourceItem):
     """A quick note of something of interest."""

@@ -565,13 +648,13 @@ class Note(SourceItem):
         Index("note_subject_idx", "subject"),
     )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> NotePayload:
+        return NotePayload(
             **super().as_payload(),
-            "note_type": self.note_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-        }
+            note_type=cast(str | None, self.note_type),
+            subject=cast(str | None, self.subject),
+            confidence=self.confidence_dict,
+        )

     @property
     def display_contents(self) -> dict:

@@ -602,6 +685,19 @@ class Note(SourceItem):
             self.as_text(cast(str, self.content), cast(str | None, self.subject))
         )

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["text"]  # Notes go to the text collection
+
+
+class AgentObservationPayload(SourceItemPayload):
+    session_id: Annotated[str | None, "Session ID for the observation"]
+    observation_type: Annotated[str, "Type of observation"]
+    subject: Annotated[str, "What/who the observation is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the observation"]
+    evidence: Annotated[dict | None, "Supporting context, quotes, etc."]
+    agent_model: Annotated[str, "Which AI model made this observation"]
+

 class AgentObservation(SourceItem):
     """

@@ -652,18 +748,16 @@ class AgentObservation(SourceItem):
         kwargs["modality"] = "observation"
         super().__init__(**kwargs)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> AgentObservationPayload:
+        return AgentObservationPayload(
             **super().as_payload(),
-            "observation_type": self.observation_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-            "evidence": self.evidence,
-            "agent_model": self.agent_model,
-        }
-        if self.session_id is not None:
-            payload["session_id"] = str(self.session_id)
-        return payload
+            observation_type=cast(str, self.observation_type),
+            subject=cast(str, self.subject),
+            confidence=self.confidence_dict,
+            evidence=cast(dict | None, self.evidence),
+            agent_model=cast(str, self.agent_model),
+            session_id=cast(str | None, self.session_id) and str(self.session_id),
+        )

     @property
     def all_contradictions(self):

@@ -759,3 +853,7 @@ class AgentObservation(SourceItem):
         # ))

         return chunks
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["semantic", "temporal"]

@@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
     return image


-def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
+def doc_to_images(
+    content: bytes | str | pathlib.Path, modality: str = "doc"
+) -> list[DataChunk]:
     with as_file(content) as file_path:
         with pymupdf.open(file_path) as pdf:
             return [

@@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
                         "height": page.rect.height,
                     },
                     mime_type="image/jpeg",
+                    modality=modality,
                 )
                 for page in pdf.pages()
             ]

@@ -122,6 +125,7 @@ def extract_text(
     content: bytes | str | pathlib.Path,
     chunk_size: int | None = None,
     metadata: dict[str, Any] = {},
+    modality: str = "text",
 ) -> list[DataChunk]:
     if isinstance(content, pathlib.Path):
         content = content.read_text()

@@ -130,7 +134,7 @@ def extract_text(

     content = cast(str, content)
     chunks = [
-        DataChunk(data=[c], modality="text", metadata=metadata)
+        DataChunk(data=[c], modality=modality, metadata=metadata)
         for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
     ]
     if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:

@@ -139,7 +143,7 @@ def extract_text(
             DataChunk(
                 data=[summary],
                 metadata=merge_metadata(metadata, {"tags": tags}),
-                modality="text",
+                modality=modality,
             )
         )
     return chunks

@@ -158,9 +162,7 @@ def extract_data_chunks(
         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         "application/msword",
     ]:
-        logger.info(f"Extracting content from {content}")
         chunks = extract_docx(content)
-        logger.info(f"Extracted {len(chunks)} pages from {content}")
     elif mime_type.startswith("text/"):
         chunks = extract_text(content, chunk_size)
     elif mime_type.startswith("image/"):

@@ -224,6 +224,15 @@ def get_collection_info(
     return info.model_dump()


+def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
+    """Get the size of each collection."""
+    collections = [i.name for i in client.get_collections().collections]
+    return {
+        collection_name: client.count(collection_name).count  # type: ignore
+        for collection_name in collections
+    }
+
+
 def batch_ids(
     client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
 ) -> Generator[list[str], None, None]:
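
Example use of the new helper (sketch; assumes this module's own
get_qdrant_client factory):

    client = get_qdrant_client()
    sizes = get_collection_sizes(client)
    # e.g. {"mail": 1204, "blog": 377, ...} -- point counts per collection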

@@ -6,6 +6,8 @@ load_dotenv()


 def boolean_env(key: str, default: bool = False) -> bool:
+    if key not in os.environ:
+        return default
     return os.getenv(key, "0").lower() in ("1", "true", "yes")

@@ -130,6 +132,10 @@ else:
 ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
 SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")

+# Search settings
+ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
+ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+
 # API settings
 SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
 HTTPS = boolean_env("HTTPS", False)

tests/memory/api/search/test_search_embeddings.py (new file, 174 lines)
@@ -0,0 +1,174 @@
+import pytest
+from memory.api.search.embeddings import merge_range_filter, merge_filters
+
+
+def test_merge_range_filter_new_filter():
+    """Test adding new range filters"""
+    filters = []
+    result = merge_range_filter(filters, "min_size", 100)
+    assert result == [{"key": "size", "range": {"gte": 100}}]
+
+    filters = []
+    result = merge_range_filter(filters, "max_size", 1000)
+    assert result == [{"key": "size", "range": {"lte": 1000}}]
+
+
+def test_merge_range_filter_existing_field():
+    """Test adding to existing field"""
+    filters = [{"key": "size", "range": {"lte": 1000}}]
+    result = merge_range_filter(filters, "min_size", 100)
+    assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]
+
+
+def test_merge_range_filter_override_existing():
+    """Test overriding existing values"""
+    filters = [{"key": "size", "range": {"gte": 100}}]
+    result = merge_range_filter(filters, "min_size", 200)
+    assert result == [{"key": "size", "range": {"gte": 200}}]
+
+
+def test_merge_range_filter_with_other_filters():
+    """Test adding range filter alongside other filters"""
+    filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
+    result = merge_range_filter(filters, "min_size", 100)
+
+    expected = [
+        {"key": "tags", "match": {"any": ["tag1"]}},
+        {"key": "size", "range": {"gte": 100}},
+    ]
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "key,expected_direction,expected_field",
+    [
+        ("min_sent_at", "min", "sent_at"),
+        ("max_sent_at", "max", "sent_at"),
+        ("min_published", "min", "published"),
+        ("max_published", "max", "published"),
+        ("min_size", "min", "size"),
+        ("max_size", "max", "size"),
+    ],
+)
+def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
+    """Test that field names are correctly extracted from keys"""
+    filters = []
+    result = merge_range_filter(filters, key, 100)
+
+    assert len(result) == 1
+    assert result[0]["key"] == expected_field
+    range_key = "gte" if expected_direction == "min" else "lte"
+    assert result[0]["range"][range_key] == 100
+
+
+@pytest.mark.parametrize(
+    "filter_key,filter_value",
+    [
+        ("tags", ["tag1", "tag2"]),
+        ("recipients", ["user1", "user2"]),
+        ("observation_types", ["belief", "preference"]),
+        ("authors", ["author1"]),
+    ],
+)
+def test_merge_filters_list_filters(filter_key, filter_value):
+    """Test list filters that use match any"""
+    filters = []
+    result = merge_filters(filters, filter_key, filter_value)
+    expected = [{"key": filter_key, "match": {"any": filter_value}}]
+    assert result == expected
+
+
+def test_merge_filters_min_confidences():
+    """Test min_confidences filter creates multiple range conditions"""
+    filters = []
+    confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9}
+    result = merge_filters(filters, "min_confidences", confidences)
+
+    expected = [
+        {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
+        {"key": "confidence.source_reliability", "range": {"gte": 0.9}},
+    ]
+    assert result == expected
+
+
+def test_merge_filters_source_ids():
+    """Test source_ids filter maps to id field"""
+    filters = []
+    result = merge_filters(filters, "source_ids", ["id1", "id2"])
+    expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}]
+    assert result == expected
+
+
+def test_merge_filters_range_delegation():
+    """Test range filters are properly delegated to merge_range_filter"""
+    filters = []
+    result = merge_filters(filters, "min_size", 100)
+
+    assert len(result) == 1
+    assert "range" in result[0]
+    assert result[0]["range"]["gte"] == 100
+
+
+def test_merge_filters_combined_range():
+    """Test that min/max range pairs merge into single filter"""
+    filters = []
+    filters = merge_filters(filters, "min_size", 100)
+    filters = merge_filters(filters, "max_size", 1000)
+
+    size_filters = [f for f in filters if f["key"] == "size"]
+    assert len(size_filters) == 1
+    assert size_filters[0]["range"]["gte"] == 100
+    assert size_filters[0]["range"]["lte"] == 1000
+
+
+def test_merge_filters_preserves_existing():
+    """Test that existing filters are preserved when adding new ones"""
+    existing_filters = [{"key": "existing", "match": "value"}]
+    result = merge_filters(existing_filters, "tags", ["new_tag"])
+
+    assert len(result) == 2
+    assert {"key": "existing", "match": "value"} in result
+    assert {"key": "tags", "match": {"any": ["new_tag"]}} in result
+
+
+def test_merge_filters_realistic_combination():
+    """Test a realistic filter combination for knowledge base search"""
+    filters = []
+
+    # Add typical knowledge base filters
+    filters = merge_filters(filters, "tags", ["important", "work"])
+    filters = merge_filters(filters, "min_published", "2023-01-01")
+    filters = merge_filters(filters, "max_size", 1000000)  # 1MB max
+    filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})
+
+    assert len(filters) == 4
+
+    # Check each filter type
+    tag_filter = next(f for f in filters if f["key"] == "tags")
+    assert tag_filter["match"]["any"] == ["important", "work"]
+
+    published_filter = next(f for f in filters if f["key"] == "published")
+    assert published_filter["range"]["gte"] == "2023-01-01"
+
+    size_filter = next(f for f in filters if f["key"] == "size")
+    assert size_filter["range"]["lte"] == 1000000
+
+    confidence_filter = next(
+        f for f in filters if f["key"] == "confidence.observation_accuracy"
+    )
+    assert confidence_filter["range"]["gte"] == 0.8
+
+
+def test_merge_filters_unknown_key():
+    """Test fallback behavior for unknown filter keys"""
+    filters = []
+    result = merge_filters(filters, "unknown_field", "unknown_value")
+    expected = [{"key": "unknown_field", "match": "unknown_value"}]
+    assert result == expected
+
+
+def test_merge_filters_empty_min_confidences():
+    """Test min_confidences with empty dict does nothing"""
+    filters = []
+    result = merge_filters(filters, "min_confidences", {})
+    assert result == []