mirror of
https://github.com/mruwnik/memory.git
synced 2025-06-28 15:14:45 +02:00
search filters
This commit is contained in:
parent
780e27ba04
commit
3e4e5872d1
@ -285,6 +285,7 @@ body {
|
|||||||
display: flex;
|
display: flex;
|
||||||
gap: 1rem;
|
gap: 1rem;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
.search-input {
|
.search-input {
|
||||||
@ -323,6 +324,34 @@ body {
|
|||||||
cursor: not-allowed;
|
cursor: not-allowed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.search-options {
|
||||||
|
border-top: 1px solid #e2e8f0;
|
||||||
|
padding-top: 1.5rem;
|
||||||
|
display: grid;
|
||||||
|
gap: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-option {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-option label {
|
||||||
|
font-weight: 500;
|
||||||
|
color: #4a5568;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-option input[type="checkbox"] {
|
||||||
|
margin-right: 0.5rem;
|
||||||
|
transform: scale(1.1);
|
||||||
|
accent-color: #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
/* Search Results */
|
/* Search Results */
|
||||||
.search-results {
|
.search-results {
|
||||||
margin-top: 2rem;
|
margin-top: 2rem;
|
||||||
@ -429,6 +458,11 @@ body {
|
|||||||
transition: background-color 0.2s;
|
transition: background-color 0.2s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.tag.selected {
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
.tag:hover {
|
.tag:hover {
|
||||||
background: #e2e8f0;
|
background: #e2e8f0;
|
||||||
}
|
}
|
||||||
@ -495,6 +529,14 @@ body {
|
|||||||
padding: 1.5rem;
|
padding: 1.5rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.search-options {
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.limit-input {
|
||||||
|
width: 80px;
|
||||||
|
}
|
||||||
|
|
||||||
.search-result-card {
|
.search-result-card {
|
||||||
padding: 1rem;
|
padding: 1rem;
|
||||||
}
|
}
|
||||||
@ -689,4 +731,194 @@ body {
|
|||||||
|
|
||||||
.markdown-preview a:hover {
|
.markdown-preview a:hover {
|
||||||
color: #2b6cb0;
|
color: #2b6cb0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dynamic filters styles */
|
||||||
|
.modality-filters {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1rem;
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modality-filter-group {
|
||||||
|
border: 1px solid #e2e8f0;
|
||||||
|
border-radius: 8px;
|
||||||
|
background: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modality-filter-title {
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #4a5568;
|
||||||
|
background: #f8fafc;
|
||||||
|
border-radius: 7px 7px 0 0;
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modality-filter-title:hover {
|
||||||
|
background: #edf2f7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modality-filter-group[open] .modality-filter-title {
|
||||||
|
border-bottom: 1px solid #e2e8f0;
|
||||||
|
border-radius: 7px 7px 0 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filters-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||||
|
gap: 1rem;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-field {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-label {
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-weight: 500;
|
||||||
|
color: #4a5568;
|
||||||
|
margin-bottom: 0.25rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-input {
|
||||||
|
padding: 0.5rem;
|
||||||
|
border: 1px solid #e2e8f0;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
background: white;
|
||||||
|
transition: border-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.filter-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
box-shadow: 0 0 0 1px #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Selectable tags controls */
|
||||||
|
.selectable-tags-details {
|
||||||
|
border: 1px solid #e2e8f0;
|
||||||
|
border-radius: 8px;
|
||||||
|
background: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.selectable-tags-summary {
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: #4a5568;
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
border-radius: 7px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.selectable-tags-summary:hover {
|
||||||
|
background: #f8fafc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.selectable-tags-details[open] .selectable-tags-summary {
|
||||||
|
border-bottom: 1px solid #e2e8f0;
|
||||||
|
border-radius: 7px 7px 0 0;
|
||||||
|
background: #f8fafc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-controls {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
border-bottom: 1px solid #e2e8f0;
|
||||||
|
background: #fafbfc;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-control-btn {
|
||||||
|
background: #f7fafc;
|
||||||
|
border: 1px solid #e2e8f0;
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 0.25rem 0.5rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: #4a5568;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-control-btn:hover:not(:disabled) {
|
||||||
|
background: #edf2f7;
|
||||||
|
border-color: #cbd5e0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-control-btn:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-count {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: #718096;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Tag search controls */
|
||||||
|
.tag-search-controls {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 1rem;
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
background: #f8fafc;
|
||||||
|
border-bottom: 1px solid #e2e8f0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-search-input {
|
||||||
|
flex: 1;
|
||||||
|
padding: 0.375rem 0.5rem;
|
||||||
|
border: 1px solid #e2e8f0;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
background: white;
|
||||||
|
transition: border-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-search-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
box-shadow: 0 0 0 1px #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
.filtered-count {
|
||||||
|
font-size: 0.75rem;
|
||||||
|
color: #718096;
|
||||||
|
font-weight: 500;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tags-display-area {
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 768px) {
|
||||||
|
.filters-grid {
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
gap: 0.5rem;
|
||||||
|
padding: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modality-filter-title {
|
||||||
|
padding: 0.5rem 0.75rem;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tag-search-controls {
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: stretch;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
}
|
}
|
@ -2,9 +2,9 @@ import { useEffect } from 'react'
|
|||||||
import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom'
|
import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom'
|
||||||
import './App.css'
|
import './App.css'
|
||||||
|
|
||||||
import { useAuth } from './hooks/useAuth'
|
import { useAuth } from '@/hooks/useAuth'
|
||||||
import { useOAuth } from './hooks/useOAuth'
|
import { useOAuth } from '@/hooks/useOAuth'
|
||||||
import { Loading, LoginPrompt, AuthError, Dashboard, Search } from './components'
|
import { Loading, LoginPrompt, AuthError, Dashboard, Search } from '@/components'
|
||||||
|
|
||||||
// AuthWrapper handles redirects based on auth state
|
// AuthWrapper handles redirects based on auth state
|
||||||
const AuthWrapper = () => {
|
const AuthWrapper = () => {
|
||||||
@ -31,7 +31,10 @@ const AuthWrapper = () => {
|
|||||||
// Handle redirects based on auth state changes
|
// Handle redirects based on auth state changes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!isLoading) {
|
if (!isLoading) {
|
||||||
if (isAuthenticated) {
|
if (location.pathname === '/ui/logout') {
|
||||||
|
logout()
|
||||||
|
navigate('/ui/login', { replace: true })
|
||||||
|
} else if (isAuthenticated) {
|
||||||
// If authenticated and on login page, redirect to dashboard
|
// If authenticated and on login page, redirect to dashboard
|
||||||
if (location.pathname === '/ui/login' || location.pathname === '/ui') {
|
if (location.pathname === '/ui/login' || location.pathname === '/ui') {
|
||||||
navigate('/ui/dashboard', { replace: true })
|
navigate('/ui/dashboard', { replace: true })
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { Link } from 'react-router-dom'
|
import { Link } from 'react-router-dom'
|
||||||
import { useMCP } from '../hooks/useMCP'
|
import { useMCP } from '@/hooks/useMCP'
|
||||||
|
|
||||||
const Dashboard = ({ onLogout }) => {
|
const Dashboard = ({ onLogout }) => {
|
||||||
const { listNotes } = useMCP()
|
const { listNotes } = useMCP()
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import React from 'react'
|
const Loading = ({ message = "Loading..." }: { message?: string }) => {
|
||||||
|
|
||||||
const Loading = ({ message = "Loading..." }) => {
|
|
||||||
return (
|
return (
|
||||||
<div className="loading">
|
<div className="loading">
|
||||||
<h2>{message}</h2>
|
<h2>{message}</h2>
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import React from 'react'
|
const AuthError = ({ error, onRetry }: { error: string, onRetry: () => void }) => {
|
||||||
|
|
||||||
const AuthError = ({ error, onRetry }) => {
|
|
||||||
return (
|
return (
|
||||||
<div className="error">
|
<div className="error">
|
||||||
<h2>Authentication Error</h2>
|
<h2>Authentication Error</h2>
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import React from 'react'
|
const LoginPrompt = ({ onLogin }: { onLogin: () => void }) => {
|
||||||
|
|
||||||
const LoginPrompt = ({ onLogin }) => {
|
|
||||||
return (
|
return (
|
||||||
<div className="login-prompt">
|
<div className="login-prompt">
|
||||||
<h1>Memory App</h1>
|
<h1>Memory App</h1>
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
export { default as Loading } from './Loading'
|
export { default as Loading } from './Loading'
|
||||||
export { default as Dashboard } from './Dashboard'
|
export { default as Dashboard } from './Dashboard'
|
||||||
export { default as Search } from './Search'
|
export { default as Search } from './search'
|
||||||
export { default as LoginPrompt } from './auth/LoginPrompt'
|
export { default as LoginPrompt } from './auth/LoginPrompt'
|
||||||
export { default as AuthError } from './auth/AuthError'
|
export { default as AuthError } from './auth/AuthError'
|
155
frontend/src/components/search/DynamicFilters.tsx
Normal file
155
frontend/src/components/search/DynamicFilters.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import { FilterInput } from './FilterInput'
|
||||||
|
import { CollectionMetadata, SchemaArg } from '@/types/mcp'
|
||||||
|
|
||||||
|
// Pure helper functions for schema processing
|
||||||
|
const formatFieldLabel = (field: string): string =>
|
||||||
|
field.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase())
|
||||||
|
|
||||||
|
const shouldSkipField = (fieldName: string): boolean =>
|
||||||
|
['tags'].includes(fieldName)
|
||||||
|
|
||||||
|
const isCommonField = (fieldName: string): boolean =>
|
||||||
|
['size', 'filename', 'content_type', 'mail_message_id', 'sent_at', 'created_at', 'source_id'].includes(fieldName)
|
||||||
|
|
||||||
|
const createMinMaxFields = (fieldName: string, fieldConfig: any): [string, any][] => [
|
||||||
|
[`min_${fieldName}`, { ...fieldConfig, description: `Min ${fieldConfig.description}` }],
|
||||||
|
[`max_${fieldName}`, { ...fieldConfig, description: `Max ${fieldConfig.description}` }]
|
||||||
|
]
|
||||||
|
|
||||||
|
const createSizeFields = (): [string, any][] => [
|
||||||
|
['min_size', { type: 'int', description: 'Minimum size in bytes' }],
|
||||||
|
['max_size', { type: 'int', description: 'Maximum size in bytes' }]
|
||||||
|
]
|
||||||
|
|
||||||
|
const extractSchemaFields = (schema: Record<string, SchemaArg>, includeCommon = true): [string, SchemaArg][] =>
|
||||||
|
Object.entries(schema)
|
||||||
|
.filter(([fieldName]) => !shouldSkipField(fieldName))
|
||||||
|
.filter(([fieldName]) => includeCommon || !isCommonField(fieldName))
|
||||||
|
.flatMap(([fieldName, fieldConfig]) =>
|
||||||
|
['sent_at', 'published', 'created_at'].includes(fieldName)
|
||||||
|
? createMinMaxFields(fieldName, fieldConfig)
|
||||||
|
: [[fieldName, fieldConfig] as [string, SchemaArg]]
|
||||||
|
)
|
||||||
|
|
||||||
|
const getCommonFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): [string, SchemaArg][] => {
|
||||||
|
const commonFieldsMap = new Map<string, SchemaArg>()
|
||||||
|
|
||||||
|
// Manually add created_at fields even if not in schema
|
||||||
|
createMinMaxFields('created_at', { type: 'datetime', description: 'Creation date' }).forEach(([field, config]) => {
|
||||||
|
commonFieldsMap.set(field, config)
|
||||||
|
})
|
||||||
|
|
||||||
|
selectedModalities.forEach(modality => {
|
||||||
|
const schema = schemas[modality].schema
|
||||||
|
if (!schema) return
|
||||||
|
|
||||||
|
Object.entries(schema).forEach(([fieldName, fieldConfig]) => {
|
||||||
|
if (isCommonField(fieldName)) {
|
||||||
|
if (['sent_at', 'created_at'].includes(fieldName)) {
|
||||||
|
createMinMaxFields(fieldName, fieldConfig).forEach(([field, config]) => {
|
||||||
|
commonFieldsMap.set(field, config)
|
||||||
|
})
|
||||||
|
} else if (fieldName === 'size') {
|
||||||
|
createSizeFields().forEach(([field, config]) => {
|
||||||
|
commonFieldsMap.set(field, config)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
commonFieldsMap.set(fieldName, fieldConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return Array.from(commonFieldsMap.entries())
|
||||||
|
}
|
||||||
|
|
||||||
|
const getModalityFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): Record<string, [string, SchemaArg][]> => {
|
||||||
|
return selectedModalities.reduce((acc, modality) => {
|
||||||
|
const schema = schemas[modality].schema
|
||||||
|
if (!schema) return acc
|
||||||
|
|
||||||
|
const schemaFields = extractSchemaFields(schema, false) // Exclude common fields
|
||||||
|
|
||||||
|
if (schemaFields.length > 0) {
|
||||||
|
acc[modality] = schemaFields
|
||||||
|
}
|
||||||
|
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, [string, SchemaArg][]>)
|
||||||
|
}
|
||||||
|
|
||||||
|
interface DynamicFiltersProps {
|
||||||
|
schemas: Record<string, CollectionMetadata>
|
||||||
|
selectedModalities: string[]
|
||||||
|
filters: Record<string, SchemaArg>
|
||||||
|
onFilterChange: (field: string, value: SchemaArg) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export const DynamicFilters = ({
|
||||||
|
schemas,
|
||||||
|
selectedModalities,
|
||||||
|
filters,
|
||||||
|
onFilterChange
|
||||||
|
}: DynamicFiltersProps) => {
|
||||||
|
const commonFields = getCommonFields(schemas, selectedModalities)
|
||||||
|
const modalityFields = getModalityFields(schemas, selectedModalities)
|
||||||
|
|
||||||
|
if (commonFields.length === 0 && Object.keys(modalityFields).length === 0) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="search-option">
|
||||||
|
<label>Filters:</label>
|
||||||
|
<div className="modality-filters">
|
||||||
|
{/* Common/Document Properties Section */}
|
||||||
|
{commonFields.length > 0 && (
|
||||||
|
<details className="modality-filter-group" open>
|
||||||
|
<summary className="modality-filter-title">
|
||||||
|
Document Properties
|
||||||
|
</summary>
|
||||||
|
<div className="filters-grid">
|
||||||
|
{commonFields.map(([field, fieldConfig]: [string, SchemaArg]) => (
|
||||||
|
<div key={field} className="filter-field">
|
||||||
|
<label className="filter-label">
|
||||||
|
{formatFieldLabel(field)}:
|
||||||
|
</label>
|
||||||
|
<FilterInput
|
||||||
|
field={field}
|
||||||
|
fieldConfig={fieldConfig}
|
||||||
|
value={filters[field]}
|
||||||
|
onChange={onFilterChange}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Modality-specific sections */}
|
||||||
|
{Object.entries(modalityFields).map(([modality, fields]) => (
|
||||||
|
<details key={modality} className="modality-filter-group">
|
||||||
|
<summary className="modality-filter-title">
|
||||||
|
{formatFieldLabel(modality)} Specific
|
||||||
|
</summary>
|
||||||
|
<div className="filters-grid">
|
||||||
|
{fields.map(([field, fieldConfig]: [string, SchemaArg]) => (
|
||||||
|
<div key={field} className="filter-field">
|
||||||
|
<label className="filter-label">
|
||||||
|
{formatFieldLabel(field)}:
|
||||||
|
</label>
|
||||||
|
<FilterInput
|
||||||
|
field={field}
|
||||||
|
fieldConfig={fieldConfig}
|
||||||
|
value={filters[field]}
|
||||||
|
onChange={onFilterChange}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
53
frontend/src/components/search/FilterInput.tsx
Normal file
53
frontend/src/components/search/FilterInput.tsx
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
// Pure helper functions
|
||||||
|
const isDateField = (field: string): boolean =>
|
||||||
|
field.includes('sent_at') || field.includes('published') || field.includes('created_at')
|
||||||
|
|
||||||
|
const isNumberField = (field: string, fieldConfig: any): boolean =>
|
||||||
|
fieldConfig.type?.includes('int') || field.includes('size')
|
||||||
|
|
||||||
|
const parseNumberValue = (value: string): number | null =>
|
||||||
|
value ? parseInt(value) : null
|
||||||
|
|
||||||
|
interface FilterInputProps {
|
||||||
|
field: string
|
||||||
|
fieldConfig: any
|
||||||
|
value: any
|
||||||
|
onChange: (field: string, value: any) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export const FilterInput = ({ field, fieldConfig, value, onChange }: FilterInputProps) => {
|
||||||
|
const inputProps = {
|
||||||
|
value: value || '',
|
||||||
|
className: "filter-input"
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isNumberField(field, fieldConfig)) {
|
||||||
|
return (
|
||||||
|
<input
|
||||||
|
{...inputProps}
|
||||||
|
type="number"
|
||||||
|
onChange={(e) => onChange(field, parseNumberValue(e.target.value))}
|
||||||
|
placeholder={fieldConfig.description}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isDateField(field)) {
|
||||||
|
return (
|
||||||
|
<input
|
||||||
|
{...inputProps}
|
||||||
|
type="date"
|
||||||
|
onChange={(e) => onChange(field, e.target.value || null)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<input
|
||||||
|
{...inputProps}
|
||||||
|
type="text"
|
||||||
|
onChange={(e) => onChange(field, e.target.value || null)}
|
||||||
|
placeholder={fieldConfig.description}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
69
frontend/src/components/search/Search.tsx
Normal file
69
frontend/src/components/search/Search.tsx
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import { useState } from 'react'
|
||||||
|
import { useNavigate } from 'react-router-dom'
|
||||||
|
import { useMCP } from '@/hooks/useMCP'
|
||||||
|
import Loading from '@/components/Loading'
|
||||||
|
import { SearchResult } from './results'
|
||||||
|
import SearchForm, { SearchParams } from './SearchForm'
|
||||||
|
|
||||||
|
const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
|
||||||
|
if (isLoading) {
|
||||||
|
return <Loading message="Searching..." />
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div className="search-results">
|
||||||
|
{results.length > 0 && (
|
||||||
|
<div className="results-count">
|
||||||
|
Found {results.length} result{results.length !== 1 ? 's' : ''}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{results.map((result, index) => <SearchResult key={index} result={result} />)}
|
||||||
|
|
||||||
|
{results.length === 0 && (
|
||||||
|
<div className="no-results">
|
||||||
|
No results found
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const Search = () => {
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const [results, setResults] = useState([])
|
||||||
|
const [isLoading, setIsLoading] = useState(false)
|
||||||
|
const { searchKnowledgeBase } = useMCP()
|
||||||
|
|
||||||
|
const handleSearch = async (params: SearchParams) => {
|
||||||
|
if (!params.query.trim()) return
|
||||||
|
|
||||||
|
setIsLoading(true)
|
||||||
|
try {
|
||||||
|
console.log(params)
|
||||||
|
const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
|
||||||
|
setResults(searchResults || [])
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Search error:', error)
|
||||||
|
setResults([])
|
||||||
|
} finally {
|
||||||
|
setIsLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="search-view">
|
||||||
|
<header className="search-header">
|
||||||
|
<button onClick={() => navigate('/ui/dashboard')} className="back-btn">
|
||||||
|
← Back to Dashboard
|
||||||
|
</button>
|
||||||
|
<h2>🔍 Search Knowledge Base</h2>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<SearchForm isLoading={isLoading} onSearch={handleSearch} />
|
||||||
|
<SearchResults results={results} isLoading={isLoading} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Search
|
151
frontend/src/components/search/SearchForm.tsx
Normal file
151
frontend/src/components/search/SearchForm.tsx
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
import { useMCP } from '@/hooks/useMCP'
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
|
import { DynamicFilters } from './DynamicFilters'
|
||||||
|
import { SelectableTags } from './SelectableTags'
|
||||||
|
import { CollectionMetadata } from '@/types/mcp'
|
||||||
|
|
||||||
|
type Filter = {
|
||||||
|
tags?: string[]
|
||||||
|
source_ids?: string[]
|
||||||
|
[key: string]: any
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SearchParams {
|
||||||
|
query: string
|
||||||
|
previews: boolean
|
||||||
|
modalities: string[]
|
||||||
|
filters: Filter
|
||||||
|
limit: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SearchFormProps {
|
||||||
|
isLoading: boolean
|
||||||
|
onSearch: (params: SearchParams) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Pure helper functions for SearchForm
|
||||||
|
const createFlags = (items: string[], defaultValue = false): Record<string, boolean> =>
|
||||||
|
items.reduce((acc, item) => ({ ...acc, [item]: defaultValue }), {})
|
||||||
|
|
||||||
|
const getSelectedItems = (items: Record<string, boolean>): string[] =>
|
||||||
|
Object.entries(items).filter(([_, selected]) => selected).map(([key]) => key)
|
||||||
|
|
||||||
|
const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
|
||||||
|
Object.entries(filters)
|
||||||
|
.filter(([_, value]) => value !== null && value !== '' && value !== undefined)
|
||||||
|
.reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {})
|
||||||
|
|
||||||
|
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
|
||||||
|
const [query, setQuery] = useState('')
|
||||||
|
const [previews, setPreviews] = useState(false)
|
||||||
|
const [modalities, setModalities] = useState<Record<string, boolean>>({})
|
||||||
|
const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
|
||||||
|
const [tags, setTags] = useState<Record<string, boolean>>({})
|
||||||
|
const [dynamicFilters, setDynamicFilters] = useState<Record<string, any>>({})
|
||||||
|
const [limit, setLimit] = useState(10)
|
||||||
|
const { getMetadataSchemas, getTags } = useMCP()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const setupFilters = async () => {
|
||||||
|
const [schemas, tags] = await Promise.all([
|
||||||
|
getMetadataSchemas(),
|
||||||
|
getTags()
|
||||||
|
])
|
||||||
|
setSchemas(schemas)
|
||||||
|
setModalities(createFlags(Object.keys(schemas), true))
|
||||||
|
setTags(createFlags(tags))
|
||||||
|
}
|
||||||
|
setupFilters()
|
||||||
|
}, [getMetadataSchemas, getTags])
|
||||||
|
|
||||||
|
const handleFilterChange = (field: string, value: any) =>
|
||||||
|
setDynamicFilters(prev => ({ ...prev, [field]: value }))
|
||||||
|
|
||||||
|
const handleSubmit = (e: React.FormEvent) => {
|
||||||
|
e.preventDefault()
|
||||||
|
|
||||||
|
onSearch({
|
||||||
|
query,
|
||||||
|
previews,
|
||||||
|
modalities: getSelectedItems(modalities),
|
||||||
|
filters: {
|
||||||
|
tags: getSelectedItems(tags),
|
||||||
|
...cleanFilters(dynamicFilters)
|
||||||
|
},
|
||||||
|
limit
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form onSubmit={handleSubmit} className="search-form">
|
||||||
|
<div className="search-input-group">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={query}
|
||||||
|
onChange={(e) => setQuery(e.target.value)}
|
||||||
|
placeholder="Search your knowledge base..."
|
||||||
|
className="search-input"
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<button type="submit" disabled={isLoading} className="search-btn">
|
||||||
|
{isLoading ? 'Searching...' : 'Search'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="search-options">
|
||||||
|
<div className="search-option">
|
||||||
|
<label>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={previews}
|
||||||
|
onChange={(e) => setPreviews(e.target.checked)}
|
||||||
|
/>
|
||||||
|
Include content previews
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<SelectableTags
|
||||||
|
title="Modalities"
|
||||||
|
className="modality-checkboxes"
|
||||||
|
tags={modalities}
|
||||||
|
onSelect={(tag, selected) => setModalities({ ...modalities, [tag]: selected })}
|
||||||
|
onBatchUpdate={(updates) => setModalities(updates)}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SelectableTags
|
||||||
|
title="Tags"
|
||||||
|
className="tags-container"
|
||||||
|
tags={tags}
|
||||||
|
onSelect={(tag, selected) => setTags({ ...tags, [tag]: selected })}
|
||||||
|
onBatchUpdate={(updates) => setTags(updates)}
|
||||||
|
searchable={true}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<DynamicFilters
|
||||||
|
schemas={schemas}
|
||||||
|
selectedModalities={getSelectedItems(modalities)}
|
||||||
|
filters={dynamicFilters}
|
||||||
|
onFilterChange={handleFilterChange}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="search-option">
|
||||||
|
<label>
|
||||||
|
Max Results:
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
value={limit}
|
||||||
|
onChange={(e) => setLimit(parseInt(e.target.value) || 10)}
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
className="limit-input"
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default SearchForm
|
136
frontend/src/components/search/SelectableTags.tsx
Normal file
136
frontend/src/components/search/SelectableTags.tsx
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
interface SelectableTagProps {
|
||||||
|
tag: string
|
||||||
|
selected: boolean
|
||||||
|
onSelect: (tag: string, selected: boolean) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const SelectableTag = ({ tag, selected, onSelect }: SelectableTagProps) => {
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
className={`tag ${selected ? 'selected' : ''}`}
|
||||||
|
onClick={() => onSelect(tag, !selected)}
|
||||||
|
>
|
||||||
|
{tag}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
import { useState } from 'react'
|
||||||
|
|
||||||
|
interface SelectableTagsProps {
|
||||||
|
title: string
|
||||||
|
className: string
|
||||||
|
tags: Record<string, boolean>
|
||||||
|
onSelect: (tag: string, selected: boolean) => void
|
||||||
|
onBatchUpdate?: (updates: Record<string, boolean>) => void
|
||||||
|
searchable?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export const SelectableTags = ({ title, className, tags, onSelect, onBatchUpdate, searchable = false }: SelectableTagsProps) => {
|
||||||
|
const [searchTerm, setSearchTerm] = useState('')
|
||||||
|
const handleSelectAll = () => {
|
||||||
|
if (onBatchUpdate) {
|
||||||
|
const updates = Object.keys(tags).reduce((acc, tag) => {
|
||||||
|
acc[tag] = true
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, boolean>)
|
||||||
|
onBatchUpdate(updates)
|
||||||
|
} else {
|
||||||
|
// Fallback to individual updates (though this won't work well)
|
||||||
|
Object.keys(tags).forEach(tag => {
|
||||||
|
if (!tags[tag]) {
|
||||||
|
onSelect(tag, true)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleDeselectAll = () => {
|
||||||
|
if (onBatchUpdate) {
|
||||||
|
const updates = Object.keys(tags).reduce((acc, tag) => {
|
||||||
|
acc[tag] = false
|
||||||
|
return acc
|
||||||
|
}, {} as Record<string, boolean>)
|
||||||
|
onBatchUpdate(updates)
|
||||||
|
} else {
|
||||||
|
// Fallback to individual updates (though this won't work well)
|
||||||
|
Object.keys(tags).forEach(tag => {
|
||||||
|
if (tags[tag]) {
|
||||||
|
onSelect(tag, false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter tags based on search term
|
||||||
|
const filteredTags = Object.entries(tags).filter(([tag, selected]) => {
|
||||||
|
return !searchTerm || tag.toLowerCase().includes(searchTerm.toLowerCase())
|
||||||
|
})
|
||||||
|
|
||||||
|
const selectedCount = Object.values(tags).filter(Boolean).length
|
||||||
|
const totalCount = Object.keys(tags).length
|
||||||
|
const filteredSelectedCount = filteredTags.filter(([_, selected]) => selected).length
|
||||||
|
const filteredTotalCount = filteredTags.length
|
||||||
|
const allSelected = selectedCount === totalCount
|
||||||
|
const noneSelected = selectedCount === 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="search-option">
|
||||||
|
<details className="selectable-tags-details">
|
||||||
|
<summary className="selectable-tags-summary">
|
||||||
|
{title} ({selectedCount} selected)
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
<div className="tag-controls">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className="tag-control-btn"
|
||||||
|
onClick={handleSelectAll}
|
||||||
|
disabled={allSelected}
|
||||||
|
>
|
||||||
|
All
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className="tag-control-btn"
|
||||||
|
onClick={handleDeselectAll}
|
||||||
|
disabled={noneSelected}
|
||||||
|
>
|
||||||
|
None
|
||||||
|
</button>
|
||||||
|
<span className="tag-count">
|
||||||
|
({selectedCount}/{totalCount})
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{searchable && (
|
||||||
|
<div className="tag-search-controls">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
placeholder={`Search ${title.toLowerCase()}...`}
|
||||||
|
value={searchTerm}
|
||||||
|
onChange={(e) => setSearchTerm(e.target.value)}
|
||||||
|
className="tag-search-input"
|
||||||
|
/>
|
||||||
|
{searchTerm && (
|
||||||
|
<span className="filtered-count">
|
||||||
|
Showing {filteredSelectedCount}/{filteredTotalCount}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className={`${className} tags-display-area`}>
|
||||||
|
{filteredTags.map(([tag, selected]: [string, boolean]) => (
|
||||||
|
<SelectableTag
|
||||||
|
key={tag}
|
||||||
|
tag={tag}
|
||||||
|
selected={selected}
|
||||||
|
onSelect={onSelect}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
3
frontend/src/components/search/index.js
Normal file
3
frontend/src/components/search/index.js
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import Search from './Search'
|
||||||
|
|
||||||
|
export default Search
|
@ -1,10 +1,8 @@
|
|||||||
import React, { useState, useEffect } from 'react'
|
import { useState, useEffect } from 'react'
|
||||||
import { useNavigate } from 'react-router-dom'
|
|
||||||
import ReactMarkdown from 'react-markdown'
|
import ReactMarkdown from 'react-markdown'
|
||||||
import { useMCP } from '../hooks/useMCP'
|
import { useMCP } from '@/hooks/useMCP'
|
||||||
import Loading from './Loading'
|
|
||||||
|
|
||||||
type SearchItem = {
|
export type SearchItem = {
|
||||||
filename: string
|
filename: string
|
||||||
content: string
|
content: string
|
||||||
chunks: any[]
|
chunks: any[]
|
||||||
@ -13,7 +11,7 @@ type SearchItem = {
|
|||||||
metadata: any
|
metadata: any
|
||||||
}
|
}
|
||||||
|
|
||||||
const Tag = ({ tags }: { tags: string[] }) => {
|
export const Tag = ({ tags }: { tags: string[] }) => {
|
||||||
return (
|
return (
|
||||||
<div className="tags">
|
<div className="tags">
|
||||||
{tags?.map((tag: string, index: number) => (
|
{tags?.map((tag: string, index: number) => (
|
||||||
@ -23,11 +21,12 @@ const Tag = ({ tags }: { tags: string[] }) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
|
export const TextResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
|
||||||
return (
|
return (
|
||||||
<div className="search-result-card">
|
<div className="search-result-card">
|
||||||
<h4>{filename || 'Untitled'}</h4>
|
<h4>{filename || 'Untitled'}</h4>
|
||||||
<Tag tags={tags} />
|
<Tag tags={tags} />
|
||||||
|
<Metadata metadata={metadata} />
|
||||||
<p className="result-content">{content || 'No content available'}</p>
|
<p className="result-content">{content || 'No content available'}</p>
|
||||||
{chunks && chunks.length > 0 && (
|
{chunks && chunks.length > 0 && (
|
||||||
<details className="result-chunks">
|
<details className="result-chunks">
|
||||||
@ -44,7 +43,7 @@ const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
|
export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
|
||||||
return (
|
return (
|
||||||
<div className="search-result-card">
|
<div className="search-result-card">
|
||||||
<h4>{filename || 'Untitled'}</h4>
|
<h4>{filename || 'Untitled'}</h4>
|
||||||
@ -70,7 +69,7 @@ const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchIte
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
|
export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
|
||||||
const title = metadata?.title || filename || 'Untitled'
|
const title = metadata?.title || filename || 'Untitled'
|
||||||
const { fetchFile } = useMCP()
|
const { fetchFile } = useMCP()
|
||||||
const [mime_type, setMimeType] = useState<string>()
|
const [mime_type, setMimeType] = useState<string>()
|
||||||
@ -95,7 +94,7 @@ const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const Metadata = ({ metadata }: { metadata: any }) => {
|
export const Metadata = ({ metadata }: { metadata: any }) => {
|
||||||
if (!metadata) return null
|
if (!metadata) return null
|
||||||
return (
|
return (
|
||||||
<div className="metadata">
|
<div className="metadata">
|
||||||
@ -108,7 +107,7 @@ const Metadata = ({ metadata }: { metadata: any }) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
|
export const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
|
||||||
return (
|
return (
|
||||||
<div className="search-result-card">
|
<div className="search-result-card">
|
||||||
<h4>{filename || 'Untitled'}</h4>
|
<h4>{filename || 'Untitled'}</h4>
|
||||||
@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const EmailResult = ({ content, tags, metadata }: SearchItem) => {
|
export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
|
||||||
return (
|
return (
|
||||||
<div className="search-result-card">
|
<div className="search-result-card">
|
||||||
<h4>{metadata?.title || metadata?.subject || 'Untitled'}</h4>
|
<h4>{metadata?.title || metadata?.subject || 'Untitled'}</h4>
|
||||||
@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const SearchResult = ({ result }: { result: SearchItem }) => {
|
export const SearchResult = ({ result }: { result: SearchItem }) => {
|
||||||
if (result.mime_type.startsWith('image/')) {
|
if (result.mime_type.startsWith('image/')) {
|
||||||
return <ImageResult {...result} />
|
return <ImageResult {...result} />
|
||||||
}
|
}
|
||||||
@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
|
export default SearchResult
|
||||||
if (isLoading) {
|
|
||||||
return <Loading message="Searching..." />
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<div className="search-results">
|
|
||||||
{results.length > 0 && (
|
|
||||||
<div className="results-count">
|
|
||||||
Found {results.length} result{results.length !== 1 ? 's' : ''}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{results.map((result, index) => <SearchResult key={index} result={result} />)}
|
|
||||||
|
|
||||||
{results.length === 0 && (
|
|
||||||
<div className="no-results">
|
|
||||||
No results found
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => {
|
|
||||||
const [query, setQuery] = useState('')
|
|
||||||
return (
|
|
||||||
<form onSubmit={(e) => {
|
|
||||||
e.preventDefault()
|
|
||||||
onSearch(query)
|
|
||||||
}} className="search-form">
|
|
||||||
<div className="search-input-group">
|
|
||||||
<input
|
|
||||||
type="text"
|
|
||||||
value={query}
|
|
||||||
onChange={(e) => setQuery(e.target.value)}
|
|
||||||
placeholder="Search your knowledge base..."
|
|
||||||
className="search-input"
|
|
||||||
/>
|
|
||||||
<button type="submit" disabled={isLoading} className="search-btn">
|
|
||||||
{isLoading ? 'Searching...' : 'Search'}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const Search = () => {
|
|
||||||
const navigate = useNavigate()
|
|
||||||
const [results, setResults] = useState([])
|
|
||||||
const [isLoading, setIsLoading] = useState(false)
|
|
||||||
const { searchKnowledgeBase } = useMCP()
|
|
||||||
|
|
||||||
const handleSearch = async (query: string) => {
|
|
||||||
if (!query.trim()) return
|
|
||||||
|
|
||||||
setIsLoading(true)
|
|
||||||
try {
|
|
||||||
const searchResults = await searchKnowledgeBase(query)
|
|
||||||
setResults(searchResults || [])
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Search error:', error)
|
|
||||||
setResults([])
|
|
||||||
} finally {
|
|
||||||
setIsLoading(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="search-view">
|
|
||||||
<header className="search-header">
|
|
||||||
<button onClick={() => navigate('/ui/dashboard')} className="back-btn">
|
|
||||||
← Back to Dashboard
|
|
||||||
</button>
|
|
||||||
<h2>🔍 Search Knowledge Base</h2>
|
|
||||||
</header>
|
|
||||||
|
|
||||||
<SearchForm isLoading={isLoading} onSearch={handleSearch} />
|
|
||||||
<SearchResults results={results} isLoading={isLoading} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
export default Search
|
|
@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000'
|
|||||||
const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'
|
const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'
|
||||||
|
|
||||||
// Cookie utilities
|
// Cookie utilities
|
||||||
const getCookie = (name) => {
|
const getCookie = (name: string) => {
|
||||||
const value = `; ${document.cookie}`
|
const value = `; ${document.cookie}`
|
||||||
const parts = value.split(`; ${name}=`)
|
const parts = value.split(`; ${name}=`)
|
||||||
if (parts.length === 2) return parts.pop().split(';').shift()
|
if (parts.length === 2) return parts.pop().split(';').shift()
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
const setCookie = (name, value, days = 30) => {
|
const setCookie = (name: string, value: string, days = 30) => {
|
||||||
const expires = new Date()
|
const expires = new Date()
|
||||||
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
|
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
|
||||||
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
|
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteCookie = (name) => {
|
const deleteCookie = (name: string) => {
|
||||||
document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
|
document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,6 +68,7 @@ export const useAuth = () => {
|
|||||||
deleteCookie('access_token')
|
deleteCookie('access_token')
|
||||||
deleteCookie('refresh_token')
|
deleteCookie('refresh_token')
|
||||||
deleteCookie(SESSION_COOKIE_NAME)
|
deleteCookie(SESSION_COOKIE_NAME)
|
||||||
|
localStorage.removeItem('oauth_client_id')
|
||||||
setIsAuthenticated(false)
|
setIsAuthenticated(false)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
@ -110,7 +111,7 @@ export const useAuth = () => {
|
|||||||
}, [logout])
|
}, [logout])
|
||||||
|
|
||||||
// Make authenticated API calls with automatic token refresh
|
// Make authenticated API calls with automatic token refresh
|
||||||
const apiCall = useCallback(async (endpoint, options = {}) => {
|
const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => {
|
||||||
let accessToken = getCookie('access_token')
|
let accessToken = getCookie('access_token')
|
||||||
|
|
||||||
if (!accessToken) {
|
if (!accessToken) {
|
||||||
@ -122,7 +123,7 @@ export const useAuth = () => {
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
const requestOptions = {
|
const requestOptions: RequestInit & { headers: Record<string, string> } = {
|
||||||
...options,
|
...options,
|
||||||
headers: { ...defaultHeaders, ...options.headers },
|
headers: { ...defaultHeaders, ...options.headers },
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { useEffect, useCallback } from 'react'
|
import { useEffect, useCallback } from 'react'
|
||||||
import { useAuth } from './useAuth'
|
import { useAuth } from '@/hooks/useAuth'
|
||||||
|
|
||||||
const parseServerSentEvents = async (response: Response): Promise<any> => {
|
const parseServerSentEvents = async (response: Response): Promise<any> => {
|
||||||
const reader = response.body?.getReader()
|
const reader = response.body?.getReader()
|
||||||
@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise<any> => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const useMCP = () => {
|
export const useMCP = () => {
|
||||||
const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth()
|
const { apiCall, checkAuth } = useAuth()
|
||||||
|
|
||||||
const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => {
|
const mcpCall = useCallback(async (method: string, params: any = {}) => {
|
||||||
const response = await apiCall(`/mcp${path}`, {
|
const response = await apiCall(`/mcp/${method}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Accept': 'application/json, text/event-stream',
|
'Accept': 'application/json, text/event-stream',
|
||||||
@ -118,22 +118,46 @@ export const useMCP = () => {
|
|||||||
if (resp?.result?.isError) {
|
if (resp?.result?.isError) {
|
||||||
throw new Error(resp?.result?.content[0].text)
|
throw new Error(resp?.result?.content[0].text)
|
||||||
}
|
}
|
||||||
return resp?.result?.content.map((item: any) => JSON.parse(item.text))
|
return resp?.result?.content.map((item: any) => {
|
||||||
|
try {
|
||||||
|
return JSON.parse(item.text)
|
||||||
|
} catch (e) {
|
||||||
|
return item.text
|
||||||
|
}
|
||||||
|
})
|
||||||
}, [apiCall])
|
}, [apiCall])
|
||||||
|
|
||||||
const listNotes = useCallback(async (path: string = "/") => {
|
const listNotes = useCallback(async (path: string = "/") => {
|
||||||
return await mcpCall('/note_files', 'note_files', { path })
|
return await mcpCall('note_files', { path })
|
||||||
}, [mcpCall])
|
}, [mcpCall])
|
||||||
|
|
||||||
const fetchFile = useCallback(async (filename: string) => {
|
const fetchFile = useCallback(async (filename: string) => {
|
||||||
return await mcpCall('/fetch_file', 'fetch_file', { filename })
|
return await mcpCall('fetch_file', { filename })
|
||||||
}, [mcpCall])
|
}, [mcpCall])
|
||||||
|
|
||||||
const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => {
|
const getTags = useCallback(async () => {
|
||||||
return await mcpCall('/search_knowledge_base', 'search_knowledge_base', {
|
return await mcpCall('get_all_tags')
|
||||||
|
}, [mcpCall])
|
||||||
|
|
||||||
|
const getSubjects = useCallback(async () => {
|
||||||
|
return await mcpCall('get_all_subjects')
|
||||||
|
}, [mcpCall])
|
||||||
|
|
||||||
|
const getObservationTypes = useCallback(async () => {
|
||||||
|
return await mcpCall('get_all_observation_types')
|
||||||
|
}, [mcpCall])
|
||||||
|
|
||||||
|
const getMetadataSchemas = useCallback(async () => {
|
||||||
|
return (await mcpCall('get_metadata_schemas'))[0]
|
||||||
|
}, [mcpCall])
|
||||||
|
|
||||||
|
const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
|
||||||
|
return await mcpCall('search_knowledge_base', {
|
||||||
query,
|
query,
|
||||||
|
filters,
|
||||||
|
modalities,
|
||||||
previews,
|
previews,
|
||||||
limit
|
limit,
|
||||||
})
|
})
|
||||||
}, [mcpCall])
|
}, [mcpCall])
|
||||||
|
|
||||||
@ -146,5 +170,9 @@ export const useMCP = () => {
|
|||||||
fetchFile,
|
fetchFile,
|
||||||
listNotes,
|
listNotes,
|
||||||
searchKnowledgeBase,
|
searchKnowledgeBase,
|
||||||
|
getTags,
|
||||||
|
getSubjects,
|
||||||
|
getObservationTypes,
|
||||||
|
getMetadataSchemas,
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -14,7 +14,7 @@ const generateCodeVerifier = () => {
|
|||||||
.replace(/=/g, '')
|
.replace(/=/g, '')
|
||||||
}
|
}
|
||||||
|
|
||||||
const generateCodeChallenge = async (verifier) => {
|
const generateCodeChallenge = async (verifier: string) => {
|
||||||
const data = new TextEncoder().encode(verifier)
|
const data = new TextEncoder().encode(verifier)
|
||||||
const digest = await crypto.subtle.digest('SHA-256', data)
|
const digest = await crypto.subtle.digest('SHA-256', data)
|
||||||
return btoa(String.fromCharCode(...new Uint8Array(digest)))
|
return btoa(String.fromCharCode(...new Uint8Array(digest)))
|
||||||
@ -33,7 +33,7 @@ const generateState = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Storage utilities
|
// Storage utilities
|
||||||
const setCookie = (name, value, days = 30) => {
|
const setCookie = (name: string, value: string, days = 30) => {
|
||||||
const expires = new Date()
|
const expires = new Date()
|
||||||
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
|
expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
|
||||||
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
|
document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { StrictMode } from 'react'
|
import { StrictMode } from 'react'
|
||||||
import { createRoot } from 'react-dom/client'
|
import { createRoot } from 'react-dom/client'
|
||||||
import App from './App.jsx'
|
import App from '@/App.jsx'
|
||||||
|
|
||||||
createRoot(document.getElementById('root')).render(
|
createRoot(document.getElementById('root')).render(
|
||||||
<StrictMode>
|
<StrictMode>
|
||||||
|
9
frontend/src/types/mcp.tsx
Normal file
9
frontend/src/types/mcp.tsx
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
export type CollectionMetadata = {
|
||||||
|
schema: Record<string, SchemaArg>
|
||||||
|
size: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export type SchemaArg = {
|
||||||
|
type: string
|
||||||
|
description: string
|
||||||
|
}
|
40
frontend/tsconfig.json
Normal file
40
frontend/tsconfig.json
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": [
|
||||||
|
"ES2020",
|
||||||
|
"DOM",
|
||||||
|
"DOM.Iterable"
|
||||||
|
],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
/* Absolute imports */
|
||||||
|
"baseUrl": ".",
|
||||||
|
"paths": {
|
||||||
|
"@/*": [
|
||||||
|
"src/*"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"include": [
|
||||||
|
"src"
|
||||||
|
],
|
||||||
|
"references": [
|
||||||
|
{
|
||||||
|
"path": "./tsconfig.node.json"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
12
frontend/tsconfig.node.json
Normal file
12
frontend/tsconfig.node.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"composite": true,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowSyntheticDefaultImports": true
|
||||||
|
},
|
||||||
|
"include": [
|
||||||
|
"vite.config.js"
|
||||||
|
]
|
||||||
|
}
|
@ -1,8 +1,14 @@
|
|||||||
import { defineConfig } from 'vite'
|
import { defineConfig } from 'vite'
|
||||||
import react from '@vitejs/plugin-react'
|
import react from '@vitejs/plugin-react'
|
||||||
|
import path from 'path'
|
||||||
|
|
||||||
// https://vite.dev/config/
|
// https://vite.dev/config/
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
plugins: [react()],
|
plugins: [react()],
|
||||||
base: '/ui/',
|
base: '/ui/',
|
||||||
|
resolve: {
|
||||||
|
alias: {
|
||||||
|
'@': path.resolve(__dirname, './src'),
|
||||||
|
},
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
import memory.api.MCP.manifest
|
import memory.api.MCP.manifest
|
||||||
import memory.api.MCP.memory
|
import memory.api.MCP.memory
|
||||||
|
import memory.api.MCP.metadata
|
||||||
|
@ -127,7 +127,6 @@ async def handle_login(request: Request):
|
|||||||
return login_form(request, oauth_params, "Invalid email or password")
|
return login_form(request, oauth_params, "Invalid email or password")
|
||||||
|
|
||||||
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
|
||||||
print("redirect_url", redirect_url)
|
|
||||||
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
if redirect_url.startswith("http://anysphere.cursor-retrieval"):
|
||||||
redirect_url = redirect_url.replace("http://", "cursor://")
|
redirect_url = redirect_url.replace("http://", "cursor://")
|
||||||
return RedirectResponse(url=redirect_url, status_code=302)
|
return RedirectResponse(url=redirect_url, status_code=302)
|
||||||
|
@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import pathlib
|
import pathlib
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import mimetypes
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Text, func
|
from sqlalchemy import Text
|
||||||
from sqlalchemy import cast as sql_cast
|
from sqlalchemy import cast as sql_cast
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
|
|
||||||
|
from memory.api.MCP.tools import mcp
|
||||||
from memory.api.search.search import SearchFilters, search
|
from memory.api.search.search import SearchFilters, search
|
||||||
from memory.common import extract, settings
|
from memory.common import extract, settings
|
||||||
|
from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
|
||||||
|
from memory.common.celery_app import app as celery_app
|
||||||
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import AgentObservation, SourceItem
|
from memory.common.db.models import SourceItem, AgentObservation
|
||||||
from memory.common.formatters import observation
|
from memory.common.formatters import observation
|
||||||
from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
|
|
||||||
from memory.api.MCP.tools import mcp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -47,11 +49,13 @@ def filter_observation_source_ids(
|
|||||||
return source_ids
|
return source_ids
|
||||||
|
|
||||||
|
|
||||||
def filter_source_ids(
|
def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None:
|
||||||
modalities: set[str],
|
if source_ids := filters.get("source_ids"):
|
||||||
tags: list[str] | None = None,
|
return source_ids
|
||||||
):
|
|
||||||
if not tags:
|
tags = filters.get("tags")
|
||||||
|
size = filters.get("size")
|
||||||
|
if not (tags or size):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
@ -62,6 +66,8 @@ def filter_source_ids(
|
|||||||
items_query = items_query.filter(
|
items_query = items_query.filter(
|
||||||
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
|
||||||
)
|
)
|
||||||
|
if size:
|
||||||
|
items_query = items_query.filter(SourceItem.size == size)
|
||||||
if modalities:
|
if modalities:
|
||||||
items_query = items_query.filter(SourceItem.modality.in_(modalities))
|
items_query = items_query.filter(SourceItem.modality.in_(modalities))
|
||||||
source_ids = [item.id for item in items_query.all()]
|
source_ids = [item.id for item in items_query.all()]
|
||||||
@ -69,51 +75,12 @@ def filter_source_ids(
|
|||||||
return source_ids
|
return source_ids
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_tags() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all unique tags used across the entire knowledge base.
|
|
||||||
Returns sorted list of tags from both observations and content.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
|
||||||
return sorted({row[0] for row in tags_query if row[0] is not None})
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_subjects() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all unique subjects from observations about the user.
|
|
||||||
Returns sorted list of subject identifiers used in observations.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
return sorted(
|
|
||||||
r.subject for r in session.query(AgentObservation.subject).distinct()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def get_all_observation_types() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all observation types that have been used.
|
|
||||||
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
|
||||||
"""
|
|
||||||
with make_session() as session:
|
|
||||||
return sorted(
|
|
||||||
{
|
|
||||||
r.observation_type
|
|
||||||
for r in session.query(AgentObservation.observation_type).distinct()
|
|
||||||
if r.observation_type is not None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def search_knowledge_base(
|
async def search_knowledge_base(
|
||||||
query: str,
|
query: str,
|
||||||
previews: bool = False,
|
filters: dict[str, Any],
|
||||||
modalities: set[str] = set(),
|
modalities: set[str] = set(),
|
||||||
tags: list[str] = [],
|
previews: bool = False,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
@ -125,7 +92,7 @@ async def search_knowledge_base(
|
|||||||
query: Natural language search query - be descriptive about what you're looking for
|
query: Natural language search query - be descriptive about what you're looking for
|
||||||
previews: Include actual content in results - when false only a snippet is returned
|
previews: Include actual content in results - when false only a snippet is returned
|
||||||
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
|
||||||
tags: Filter by tags - content must have at least one matching tag
|
filters: Filter by tags, source_ids, etc.
|
||||||
limit: Max results (1-100)
|
limit: Max results (1-100)
|
||||||
|
|
||||||
Returns: List of search results with id, score, chunks, content, filename
|
Returns: List of search results with id, score, chunks, content, filename
|
||||||
@ -137,6 +104,9 @@ async def search_knowledge_base(
|
|||||||
modalities = set(ALL_COLLECTIONS.keys())
|
modalities = set(ALL_COLLECTIONS.keys())
|
||||||
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS
|
||||||
|
|
||||||
|
search_filters = SearchFilters(**filters)
|
||||||
|
search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
|
||||||
|
|
||||||
upload_data = extract.extract_text(query)
|
upload_data = extract.extract_text(query)
|
||||||
results = await search(
|
results = await search(
|
||||||
upload_data,
|
upload_data,
|
||||||
@ -145,10 +115,7 @@ async def search_knowledge_base(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
min_text_score=0.4,
|
min_text_score=0.4,
|
||||||
min_multimodal_score=0.25,
|
min_multimodal_score=0.25,
|
||||||
filters=SearchFilters(
|
filters=search_filters,
|
||||||
tags=tags,
|
|
||||||
source_ids=filter_source_ids(tags=tags, modalities=modalities),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result.model_dump() for result in results]
|
return [result.model_dump() for result in results]
|
||||||
|
119
src/memory/api/MCP/metadata.py
Normal file
119
src/memory/api/MCP/metadata.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Annotated, TypedDict, get_args, get_type_hints
|
||||||
|
|
||||||
|
from memory.common import qdrant
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from memory.api.MCP.tools import mcp
|
||||||
|
from memory.common.db.connection import make_session
|
||||||
|
from memory.common.db.models import SourceItem
|
||||||
|
from memory.common.db.models.source_items import AgentObservation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaArg(TypedDict):
|
||||||
|
type: str | None
|
||||||
|
description: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionMetadata(TypedDict):
|
||||||
|
schema: dict[str, SchemaArg]
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
def from_annotation(annotation: Annotated) -> SchemaArg | None:
|
||||||
|
try:
|
||||||
|
type_, description = get_args(annotation)
|
||||||
|
type_str = str(type_)
|
||||||
|
if type_str.startswith("typing."):
|
||||||
|
type_str = type_str[7:]
|
||||||
|
elif len((parts := type_str.split("'"))) > 1:
|
||||||
|
type_str = parts[1]
|
||||||
|
return SchemaArg(type=type_str, description=description)
|
||||||
|
except IndexError:
|
||||||
|
logger.error(f"Error from annotation: {annotation}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
|
||||||
|
if not hasattr(klass, "as_payload"):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not (payload_type := get_type_hints(klass.as_payload).get("return")):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
name: schema
|
||||||
|
for name, arg in payload_type.__annotations__.items()
|
||||||
|
if (schema := from_annotation(arg))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
|
||||||
|
"""Get the metadata schema for each collection used in the knowledge base.
|
||||||
|
|
||||||
|
These schemas can be used to filter the knowledge base.
|
||||||
|
|
||||||
|
Returns: A mapping of collection names to their metadata schemas with field types and descriptions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"mail": {"subject": {"type": "str", "description": "The subject of the email."}},
|
||||||
|
"chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
client = qdrant.get_qdrant_client()
|
||||||
|
sizes = qdrant.get_collection_sizes(client)
|
||||||
|
schemas = defaultdict(dict)
|
||||||
|
for klass in SourceItem.__subclasses__():
|
||||||
|
for collection in klass.get_collections():
|
||||||
|
schemas[collection].update(get_schema(klass))
|
||||||
|
|
||||||
|
return {
|
||||||
|
collection: CollectionMetadata(schema=schema, size=size)
|
||||||
|
for collection, schema in schemas.items()
|
||||||
|
if (size := sizes.get(collection))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_tags() -> list[str]:
|
||||||
|
"""Get all unique tags used across the entire knowledge base.
|
||||||
|
|
||||||
|
Returns sorted list of tags from both observations and content.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
|
||||||
|
return sorted({row[0] for row in tags_query if row[0] is not None})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_subjects() -> list[str]:
|
||||||
|
"""Get all unique subjects from observations about the user.
|
||||||
|
|
||||||
|
Returns sorted list of subject identifiers used in observations.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
return sorted(
|
||||||
|
r.subject for r in session.query(AgentObservation.subject).distinct()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_all_observation_types() -> list[str]:
|
||||||
|
"""Get all observation types that have been used.
|
||||||
|
|
||||||
|
Standard types are belief, preference, behavior, contradiction, general, but there can be more.
|
||||||
|
"""
|
||||||
|
with make_session() as session:
|
||||||
|
return sorted(
|
||||||
|
{
|
||||||
|
r.observation_type
|
||||||
|
for r in session.query(AgentObservation.observation_type).distinct()
|
||||||
|
if r.observation_type is not None
|
||||||
|
}
|
||||||
|
)
|
@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
|
|||||||
# SQLAdmin setup with OAuth protection
|
# SQLAdmin setup with OAuth protection
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
admin = Admin(app, engine)
|
admin = Admin(app, engine)
|
||||||
|
admin.app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # [settings.SERVER_URL],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# Setup admin with OAuth protection using existing OAuth provider
|
# Setup admin with OAuth protection using existing OAuth provider
|
||||||
setup_admin(admin)
|
setup_admin(admin)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional, cast
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -90,13 +90,71 @@ def query_chunks(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def merge_range_filter(
|
||||||
|
filters: list[dict[str, Any]], key: str, val: Any
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
direction, field = key.split("_", maxsplit=1)
|
||||||
|
item = next((f for f in filters if f["key"] == field), None)
|
||||||
|
if not item:
|
||||||
|
item = {"key": field, "range": {}}
|
||||||
|
filters.append(item)
|
||||||
|
|
||||||
|
if direction == "min":
|
||||||
|
item["range"]["gte"] = val
|
||||||
|
elif direction == "max":
|
||||||
|
item["range"]["lte"] = val
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def merge_filters(
|
||||||
|
filters: list[dict[str, Any]], key: str, val: Any
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if not val and val != 0:
|
||||||
|
return filters
|
||||||
|
|
||||||
|
list_filters = ["tags", "recipients", "observation_types", "authors"]
|
||||||
|
range_filters = [
|
||||||
|
"min_sent_at",
|
||||||
|
"max_sent_at",
|
||||||
|
"min_published",
|
||||||
|
"max_published",
|
||||||
|
"min_size",
|
||||||
|
"max_size",
|
||||||
|
"min_created_at",
|
||||||
|
"max_created_at",
|
||||||
|
]
|
||||||
|
if key in list_filters:
|
||||||
|
filters.append({"key": key, "match": {"any": val}})
|
||||||
|
|
||||||
|
elif key in range_filters:
|
||||||
|
return merge_range_filter(filters, key, val)
|
||||||
|
|
||||||
|
elif key == "min_confidences":
|
||||||
|
confidence_filters = [
|
||||||
|
{
|
||||||
|
"key": f"confidence.{confidence_type}",
|
||||||
|
"range": {"gte": min_confidence_score},
|
||||||
|
}
|
||||||
|
for confidence_type, min_confidence_score in cast(dict, val).items()
|
||||||
|
]
|
||||||
|
filters.extend(confidence_filters)
|
||||||
|
|
||||||
|
elif key == "source_ids":
|
||||||
|
filters.append({"key": "id", "match": {"any": val}})
|
||||||
|
|
||||||
|
else:
|
||||||
|
filters.append({"key": key, "match": val})
|
||||||
|
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
async def search_embeddings(
|
async def search_embeddings(
|
||||||
data: list[extract.DataChunk],
|
data: list[extract.DataChunk],
|
||||||
previews: Optional[bool] = False,
|
previews: Optional[bool] = False,
|
||||||
modalities: set[str] = set(),
|
modalities: set[str] = set(),
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
min_score: float = 0.3,
|
min_score: float = 0.3,
|
||||||
filters: SearchFilters = SearchFilters(),
|
filters: SearchFilters = {},
|
||||||
multimodal: bool = False,
|
multimodal: bool = False,
|
||||||
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
) -> list[tuple[SourceData, AnnotatedChunk]]:
|
||||||
"""
|
"""
|
||||||
@ -111,27 +169,11 @@ async def search_embeddings(
|
|||||||
- filters: Filters to apply to the search results
|
- filters: Filters to apply to the search results
|
||||||
- multimodal: Whether to search in multimodal collections
|
- multimodal: Whether to search in multimodal collections
|
||||||
"""
|
"""
|
||||||
query_filters = {"must": []}
|
search_filters = []
|
||||||
|
for key, val in filters.items():
|
||||||
# Handle structured confidence filtering
|
search_filters = merge_filters(search_filters, key, val)
|
||||||
if min_confidences := filters.get("min_confidences"):
|
|
||||||
confidence_filters = [
|
|
||||||
{
|
|
||||||
"key": f"confidence.{confidence_type}",
|
|
||||||
"range": {"gte": min_confidence_score},
|
|
||||||
}
|
|
||||||
for confidence_type, min_confidence_score in min_confidences.items()
|
|
||||||
]
|
|
||||||
query_filters["must"].extend(confidence_filters)
|
|
||||||
|
|
||||||
if tags := filters.get("tags"):
|
|
||||||
query_filters["must"].append({"key": "tags", "match": {"any": tags}})
|
|
||||||
|
|
||||||
if observation_types := filters.get("observation_types"):
|
|
||||||
query_filters["must"].append(
|
|
||||||
{"key": "observation_type", "match": {"any": observation_types}}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
print(search_filters)
|
||||||
client = qdrant.get_qdrant_client()
|
client = qdrant.get_qdrant_client()
|
||||||
results = query_chunks(
|
results = query_chunks(
|
||||||
client,
|
client,
|
||||||
@ -140,7 +182,7 @@ async def search_embeddings(
|
|||||||
embedding.embed_text if not multimodal else embedding.embed_mixed,
|
embedding.embed_text if not multimodal else embedding.embed_mixed,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
filters=query_filters if query_filters["must"] else None,
|
filters={"must": search_filters} if search_filters else None,
|
||||||
)
|
)
|
||||||
search_results = {k: results.get(k, []) for k in modalities}
|
search_results = {k: results.get(k, []) for k in modalities}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from memory.common.collections import (
|
|||||||
MULTIMODAL_COLLECTIONS,
|
MULTIMODAL_COLLECTIONS,
|
||||||
TEXT_COLLECTIONS,
|
TEXT_COLLECTIONS,
|
||||||
)
|
)
|
||||||
|
from memory.common import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -44,51 +45,57 @@ async def search(
|
|||||||
- List of search results sorted by score
|
- List of search results sorted by score
|
||||||
"""
|
"""
|
||||||
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
|
allowed_modalities = modalities & ALL_COLLECTIONS.keys()
|
||||||
|
print(allowed_modalities)
|
||||||
|
|
||||||
text_embeddings_results = with_timeout(
|
searches = []
|
||||||
search_embeddings(
|
if settings.ENABLE_EMBEDDING_SEARCH:
|
||||||
data,
|
searches = [
|
||||||
previews,
|
with_timeout(
|
||||||
allowed_modalities & TEXT_COLLECTIONS,
|
search_embeddings(
|
||||||
limit,
|
data,
|
||||||
min_text_score,
|
previews,
|
||||||
filters,
|
allowed_modalities & TEXT_COLLECTIONS,
|
||||||
multimodal=False,
|
limit,
|
||||||
),
|
min_text_score,
|
||||||
timeout,
|
filters,
|
||||||
)
|
multimodal=False,
|
||||||
multimodal_embeddings_results = with_timeout(
|
),
|
||||||
search_embeddings(
|
timeout,
|
||||||
data,
|
),
|
||||||
previews,
|
with_timeout(
|
||||||
allowed_modalities & MULTIMODAL_COLLECTIONS,
|
search_embeddings(
|
||||||
limit,
|
data,
|
||||||
min_multimodal_score,
|
previews,
|
||||||
filters,
|
allowed_modalities & MULTIMODAL_COLLECTIONS,
|
||||||
multimodal=True,
|
limit,
|
||||||
),
|
min_multimodal_score,
|
||||||
timeout,
|
filters,
|
||||||
)
|
multimodal=True,
|
||||||
bm25_results = with_timeout(
|
),
|
||||||
search_bm25(
|
timeout,
|
||||||
" ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
|
),
|
||||||
modalities,
|
]
|
||||||
limit=limit,
|
if settings.ENABLE_BM25_SEARCH:
|
||||||
filters=filters,
|
searches.append(
|
||||||
),
|
with_timeout(
|
||||||
timeout,
|
search_bm25(
|
||||||
)
|
" ".join(
|
||||||
|
[c for chunk in data for c in chunk.data if isinstance(c, str)]
|
||||||
|
),
|
||||||
|
modalities,
|
||||||
|
limit=limit,
|
||||||
|
filters=filters,
|
||||||
|
),
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
search_results = await asyncio.gather(*searches, return_exceptions=False)
|
||||||
text_embeddings_results,
|
all_results = []
|
||||||
multimodal_embeddings_results,
|
for results in search_results:
|
||||||
bm25_results,
|
if len(all_results) >= limit:
|
||||||
return_exceptions=False,
|
break
|
||||||
)
|
all_results.extend(results)
|
||||||
text_results, multi_results, bm25_results = results
|
|
||||||
all_results = text_results + multi_results
|
|
||||||
if len(all_results) < limit:
|
|
||||||
all_results += bm25_results
|
|
||||||
|
|
||||||
results = group_chunks(all_results, previews or False)
|
results = group_chunks(all_results, previews or False)
|
||||||
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
|
return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
|
||||||
|
@ -65,9 +65,9 @@ class SearchResult(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SearchFilters(TypedDict):
|
class SearchFilters(TypedDict):
|
||||||
subject: NotRequired[str | None]
|
min_size: NotRequired[int]
|
||||||
|
max_size: NotRequired[int]
|
||||||
min_confidences: NotRequired[dict[str, float]]
|
min_confidences: NotRequired[dict[str, float]]
|
||||||
tags: NotRequired[list[str] | None]
|
|
||||||
observation_types: NotRequired[list[str] | None]
|
observation_types: NotRequired[list[str] | None]
|
||||||
source_ids: NotRequired[list[int] | None]
|
source_ids: NotRequired[list[int] | None]
|
||||||
|
|
||||||
@ -115,7 +115,6 @@ def group_chunks(
|
|||||||
if isinstance(contents, dict):
|
if isinstance(contents, dict):
|
||||||
tags = contents.pop("tags", [])
|
tags = contents.pop("tags", [])
|
||||||
content = contents.pop("content", None)
|
content = contents.pop("content", None)
|
||||||
print(content)
|
|
||||||
else:
|
else:
|
||||||
content = contents
|
content = contents
|
||||||
contents = {}
|
contents = {}
|
||||||
|
@ -4,6 +4,7 @@ from memory.common.db.models.source_item import (
|
|||||||
SourceItem,
|
SourceItem,
|
||||||
ConfidenceScore,
|
ConfidenceScore,
|
||||||
clean_filename,
|
clean_filename,
|
||||||
|
SourceItemPayload,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.source_items import (
|
from memory.common.db.models.source_items import (
|
||||||
MailMessage,
|
MailMessage,
|
||||||
@ -19,6 +20,14 @@ from memory.common.db.models.source_items import (
|
|||||||
Photo,
|
Photo,
|
||||||
MiscDoc,
|
MiscDoc,
|
||||||
Note,
|
Note,
|
||||||
|
MailMessagePayload,
|
||||||
|
EmailAttachmentPayload,
|
||||||
|
AgentObservationPayload,
|
||||||
|
BlogPostPayload,
|
||||||
|
ComicPayload,
|
||||||
|
BookSectionPayload,
|
||||||
|
NotePayload,
|
||||||
|
ForumPostPayload,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.observations import (
|
from memory.common.db.models.observations import (
|
||||||
ObservationContradiction,
|
ObservationContradiction,
|
||||||
@ -40,6 +49,18 @@ from memory.common.db.models.users import (
|
|||||||
OAuthRefreshToken,
|
OAuthRefreshToken,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Payload = (
|
||||||
|
SourceItemPayload
|
||||||
|
| AgentObservationPayload
|
||||||
|
| NotePayload
|
||||||
|
| BlogPostPayload
|
||||||
|
| ComicPayload
|
||||||
|
| BookSectionPayload
|
||||||
|
| ForumPostPayload
|
||||||
|
| EmailAttachmentPayload
|
||||||
|
| MailMessagePayload
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Base",
|
"Base",
|
||||||
"Chunk",
|
"Chunk",
|
||||||
@ -75,4 +96,6 @@ __all__ = [
|
|||||||
"OAuthClientInformation",
|
"OAuthClientInformation",
|
||||||
"OAuthState",
|
"OAuthState",
|
||||||
"OAuthRefreshToken",
|
"OAuthRefreshToken",
|
||||||
|
# Payloads
|
||||||
|
"Payload",
|
||||||
]
|
]
|
||||||
|
@ -4,7 +4,7 @@ Database models for the knowledge base system.
|
|||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from typing import Any, Sequence, cast
|
from typing import Any, Annotated, Sequence, TypedDict, cast
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer
|
|||||||
from memory.common.db.models.base import Base
|
from memory.common.db.models.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataSchema(TypedDict):
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class SourceItemPayload(TypedDict):
|
||||||
|
source_id: Annotated[int, "Unique identifier of the source item"]
|
||||||
|
tags: Annotated[list[str], "List of tags for categorization"]
|
||||||
|
size: Annotated[int | None, "Size of the content in bytes"]
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(Session, "before_flush")
|
@event.listens_for(Session, "before_flush")
|
||||||
def handle_duplicate_sha256(session, flush_context, instances):
|
def handle_duplicate_sha256(session, flush_context, instances):
|
||||||
"""
|
"""
|
||||||
@ -344,12 +355,17 @@ class SourceItem(Base):
|
|||||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
||||||
return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
|
return [self._make_chunk(data, metadata) for data in self._chunk_contents()]
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> SourceItemPayload:
|
||||||
return {
|
return SourceItemPayload(
|
||||||
"source_id": self.id,
|
source_id=cast(int, self.id),
|
||||||
"tags": self.tags,
|
tags=cast(list[str], self.tags),
|
||||||
"size": self.size,
|
size=cast(int | None, self.size),
|
||||||
}
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collections(cls) -> list[str]:
|
||||||
|
"""Return the list of Qdrant collections this SourceItem type can be stored in."""
|
||||||
|
return [cls.__tablename__]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_contents(self) -> str | dict | None:
|
def display_contents(self) -> str | dict | None:
|
||||||
|
@ -5,7 +5,7 @@ Database models for the knowledge base system.
|
|||||||
import pathlib
|
import pathlib
|
||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Sequence, cast
|
from typing import Any, Annotated, Sequence, cast
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation
|
|||||||
from memory.common.db.models.source_item import (
|
from memory.common.db.models.source_item import (
|
||||||
SourceItem,
|
SourceItem,
|
||||||
Chunk,
|
Chunk,
|
||||||
|
SourceItemPayload,
|
||||||
clean_filename,
|
clean_filename,
|
||||||
chunk_mixed,
|
chunk_mixed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MailMessagePayload(SourceItemPayload):
|
||||||
|
message_id: Annotated[str, "Unique email message identifier"]
|
||||||
|
subject: Annotated[str, "Email subject line"]
|
||||||
|
sender: Annotated[str, "Email sender address"]
|
||||||
|
recipients: Annotated[list[str], "List of recipient email addresses"]
|
||||||
|
folder: Annotated[str, "Email folder name"]
|
||||||
|
date: Annotated[str | None, "Email sent date in ISO format"]
|
||||||
|
|
||||||
|
|
||||||
class MailMessage(SourceItem):
|
class MailMessage(SourceItem):
|
||||||
__tablename__ = "mail_message"
|
__tablename__ = "mail_message"
|
||||||
|
|
||||||
@ -80,17 +90,21 @@ class MailMessage(SourceItem):
|
|||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> MailMessagePayload:
|
||||||
return {
|
base_payload = super().as_payload() | {
|
||||||
**super().as_payload(),
|
"tags": cast(list[str], self.tags)
|
||||||
"message_id": self.message_id,
|
+ [cast(str, self.sender)]
|
||||||
"subject": self.subject,
|
+ cast(list[str], self.recipients)
|
||||||
"sender": self.sender,
|
|
||||||
"recipients": self.recipients,
|
|
||||||
"folder": self.folder,
|
|
||||||
"tags": self.tags + [self.sender] + self.recipients,
|
|
||||||
"date": (self.sent_at and self.sent_at.isoformat() or None), # type: ignore
|
|
||||||
}
|
}
|
||||||
|
return MailMessagePayload(
|
||||||
|
**cast(dict, base_payload),
|
||||||
|
message_id=cast(str, self.message_id),
|
||||||
|
subject=cast(str, self.subject),
|
||||||
|
sender=cast(str, self.sender),
|
||||||
|
recipients=cast(list[str], self.recipients),
|
||||||
|
folder=cast(str, self.folder),
|
||||||
|
date=(self.sent_at and self.sent_at.isoformat() or None), # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parsed_content(self) -> dict[str, Any]:
|
def parsed_content(self) -> dict[str, Any]:
|
||||||
@ -152,7 +166,7 @@ class MailMessage(SourceItem):
|
|||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
content = self.parsed_content
|
content = self.parsed_content
|
||||||
chunks = extract.extract_text(cast(str, self.body))
|
chunks = extract.extract_text(cast(str, self.body), modality="email")
|
||||||
|
|
||||||
def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
|
def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
@ -163,6 +177,10 @@ class MailMessage(SourceItem):
|
|||||||
chunk.data = [add_header(item) for item in chunk.data]
|
chunk.data = [add_header(item) for item in chunk.data]
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collections(cls) -> list[str]:
|
||||||
|
return ["mail"]
|
||||||
|
|
||||||
# Add indexes
|
# Add indexes
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("mail_sent_idx", "sent_at"),
|
Index("mail_sent_idx", "sent_at"),
|
||||||
@ -171,6 +189,13 @@ class MailMessage(SourceItem):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailAttachmentPayload(SourceItemPayload):
|
||||||
|
filename: Annotated[str, "Name of the document file"]
|
||||||
|
content_type: Annotated[str, "MIME type of the document"]
|
||||||
|
mail_message_id: Annotated[int, "Associated email message ID"]
|
||||||
|
sent_at: Annotated[str | None, "Document creation timestamp"]
|
||||||
|
|
||||||
|
|
||||||
class EmailAttachment(SourceItem):
|
class EmailAttachment(SourceItem):
|
||||||
__tablename__ = "email_attachment"
|
__tablename__ = "email_attachment"
|
||||||
|
|
||||||
@ -190,17 +215,20 @@ class EmailAttachment(SourceItem):
|
|||||||
"polymorphic_identity": "email_attachment",
|
"polymorphic_identity": "email_attachment",
|
||||||
}
|
}
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> EmailAttachmentPayload:
|
||||||
return {
|
return EmailAttachmentPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"filename": self.filename,
|
filename=cast(str, self.filename),
|
||||||
"content_type": self.mime_type,
|
content_type=cast(str, self.mime_type),
|
||||||
"size": self.size,
|
mail_message_id=cast(int, self.mail_message_id),
|
||||||
"created_at": (self.created_at and self.created_at.isoformat() or None), # type: ignore
|
sent_at=(
|
||||||
"mail_message_id": self.mail_message_id,
|
self.mail_message.sent_at
|
||||||
}
|
and self.mail_message.sent_at.isoformat()
|
||||||
|
or None
|
||||||
|
), # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
if cast(str | None, self.filename):
|
if cast(str | None, self.filename):
|
||||||
contents = (
|
contents = (
|
||||||
settings.FILE_STORAGE_DIR / cast(str, self.filename)
|
settings.FILE_STORAGE_DIR / cast(str, self.filename)
|
||||||
@ -208,8 +236,7 @@ class EmailAttachment(SourceItem):
|
|||||||
else:
|
else:
|
||||||
contents = cast(str, self.content)
|
contents = cast(str, self.content)
|
||||||
|
|
||||||
chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
|
return extract.extract_data_chunks(cast(str, self.mime_type), contents)
|
||||||
return [self._make_chunk(c, metadata) for c in chunks]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_contents(self) -> dict:
|
def display_contents(self) -> dict:
|
||||||
@ -221,6 +248,11 @@ class EmailAttachment(SourceItem):
|
|||||||
# Add indexes
|
# Add indexes
|
||||||
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
__table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collections(cls) -> list[str]:
|
||||||
|
"""EmailAttachment can go to different collections based on mime_type"""
|
||||||
|
return ["doc", "text", "blog", "photo", "book"]
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(SourceItem):
|
class ChatMessage(SourceItem):
|
||||||
__tablename__ = "chat_message"
|
__tablename__ = "chat_message"
|
||||||
@ -285,6 +317,16 @@ class Photo(SourceItem):
|
|||||||
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
|
__table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)
|
||||||
|
|
||||||
|
|
||||||
|
class ComicPayload(SourceItemPayload):
|
||||||
|
title: Annotated[str, "Title of the comic"]
|
||||||
|
author: Annotated[str | None, "Author of the comic"]
|
||||||
|
published: Annotated[str | None, "Publication date in ISO format"]
|
||||||
|
volume: Annotated[str | None, "Volume number"]
|
||||||
|
issue: Annotated[str | None, "Issue number"]
|
||||||
|
page: Annotated[int | None, "Page number"]
|
||||||
|
url: Annotated[str | None, "URL of the comic"]
|
||||||
|
|
||||||
|
|
||||||
class Comic(SourceItem):
|
class Comic(SourceItem):
|
||||||
__tablename__ = "comic"
|
__tablename__ = "comic"
|
||||||
|
|
||||||
@ -305,18 +347,17 @@ class Comic(SourceItem):
|
|||||||
|
|
||||||
__table_args__ = (Index("comic_author_idx", "author"),)
|
__table_args__ = (Index("comic_author_idx", "author"),)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> ComicPayload:
|
||||||
payload = {
|
return ComicPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"title": self.title,
|
title=cast(str, self.title),
|
||||||
"author": self.author,
|
author=cast(str | None, self.author),
|
||||||
"published": self.published,
|
published=(self.published and self.published.isoformat() or None), # type: ignore
|
||||||
"volume": self.volume,
|
volume=cast(str | None, self.volume),
|
||||||
"issue": self.issue,
|
issue=cast(str | None, self.issue),
|
||||||
"page": self.page,
|
page=cast(int | None, self.page),
|
||||||
"url": self.url,
|
url=cast(str | None, self.url),
|
||||||
}
|
)
|
||||||
return {k: v for k, v in payload.items() if v is not None}
|
|
||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
|
image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
|
||||||
@ -324,6 +365,17 @@ class Comic(SourceItem):
|
|||||||
return [extract.DataChunk(data=[image, description])]
|
return [extract.DataChunk(data=[image, description])]
|
||||||
|
|
||||||
|
|
||||||
|
class BookSectionPayload(SourceItemPayload):
|
||||||
|
title: Annotated[str, "Title of the book"]
|
||||||
|
author: Annotated[str | None, "Author of the book"]
|
||||||
|
book_id: Annotated[int, "Unique identifier of the book"]
|
||||||
|
section_title: Annotated[str, "Title of the section"]
|
||||||
|
section_number: Annotated[int, "Number of the section"]
|
||||||
|
section_level: Annotated[int, "Level of the section"]
|
||||||
|
start_page: Annotated[int, "Starting page number"]
|
||||||
|
end_page: Annotated[int, "Ending page number"]
|
||||||
|
|
||||||
|
|
||||||
class BookSection(SourceItem):
|
class BookSection(SourceItem):
|
||||||
"""Individual sections/chapters of books"""
|
"""Individual sections/chapters of books"""
|
||||||
|
|
||||||
@ -361,19 +413,22 @@ class BookSection(SourceItem):
|
|||||||
Index("book_section_level_idx", "section_level", "section_number"),
|
Index("book_section_level_idx", "section_level", "section_number"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
@classmethod
|
||||||
vals = {
|
def get_collections(cls) -> list[str]:
|
||||||
|
return ["book"]
|
||||||
|
|
||||||
|
def as_payload(self) -> BookSectionPayload:
|
||||||
|
return BookSectionPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"title": self.book.title,
|
title=cast(str, self.book.title),
|
||||||
"author": self.book.author,
|
author=cast(str | None, self.book.author),
|
||||||
"book_id": self.book_id,
|
book_id=cast(int, self.book_id),
|
||||||
"section_title": self.section_title,
|
section_title=cast(str, self.section_title),
|
||||||
"section_number": self.section_number,
|
section_number=cast(int, self.section_number),
|
||||||
"section_level": self.section_level,
|
section_level=cast(int, self.section_level),
|
||||||
"start_page": self.start_page,
|
start_page=cast(int, self.start_page),
|
||||||
"end_page": self.end_page,
|
end_page=cast(int, self.end_page),
|
||||||
}
|
)
|
||||||
return {k: v for k, v in vals.items() if v}
|
|
||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
content = cast(str, self.content.strip())
|
content = cast(str, self.content.strip())
|
||||||
@ -397,6 +452,16 @@ class BookSection(SourceItem):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BlogPostPayload(SourceItemPayload):
|
||||||
|
url: Annotated[str, "URL of the blog post"]
|
||||||
|
title: Annotated[str, "Title of the blog post"]
|
||||||
|
author: Annotated[str | None, "Author of the blog post"]
|
||||||
|
published: Annotated[str | None, "Publication date in ISO format"]
|
||||||
|
description: Annotated[str | None, "Description of the blog post"]
|
||||||
|
domain: Annotated[str | None, "Domain of the blog post"]
|
||||||
|
word_count: Annotated[int | None, "Word count of the blog post"]
|
||||||
|
|
||||||
|
|
||||||
class BlogPost(SourceItem):
|
class BlogPost(SourceItem):
|
||||||
__tablename__ = "blog_post"
|
__tablename__ = "blog_post"
|
||||||
|
|
||||||
@ -428,27 +493,39 @@ class BlogPost(SourceItem):
|
|||||||
Index("blog_post_word_count_idx", "word_count"),
|
Index("blog_post_word_count_idx", "word_count"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> BlogPostPayload:
|
||||||
published_date = cast(datetime | None, self.published)
|
published_date = cast(datetime | None, self.published)
|
||||||
metadata = cast(dict | None, self.webpage_metadata) or {}
|
metadata = cast(dict | None, self.webpage_metadata) or {}
|
||||||
|
|
||||||
payload = {
|
return BlogPostPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"url": self.url,
|
url=cast(str, self.url),
|
||||||
"title": self.title,
|
title=cast(str, self.title),
|
||||||
"author": self.author,
|
author=cast(str | None, self.author),
|
||||||
"published": published_date and published_date.isoformat(),
|
published=(published_date and published_date.isoformat() or None), # type: ignore
|
||||||
"description": self.description,
|
description=cast(str | None, self.description),
|
||||||
"domain": self.domain,
|
domain=cast(str | None, self.domain),
|
||||||
"word_count": self.word_count,
|
word_count=cast(int | None, self.word_count),
|
||||||
**metadata,
|
**metadata,
|
||||||
}
|
)
|
||||||
return {k: v for k, v in payload.items() if v}
|
|
||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
||||||
|
|
||||||
|
|
||||||
|
class ForumPostPayload(SourceItemPayload):
|
||||||
|
url: Annotated[str, "URL of the forum post"]
|
||||||
|
title: Annotated[str, "Title of the forum post"]
|
||||||
|
description: Annotated[str | None, "Description of the forum post"]
|
||||||
|
authors: Annotated[list[str] | None, "Authors of the forum post"]
|
||||||
|
published: Annotated[str | None, "Publication date in ISO format"]
|
||||||
|
slug: Annotated[str | None, "Slug of the forum post"]
|
||||||
|
karma: Annotated[int | None, "Karma score of the forum post"]
|
||||||
|
votes: Annotated[int | None, "Number of votes on the forum post"]
|
||||||
|
score: Annotated[int | None, "Score of the forum post"]
|
||||||
|
comments: Annotated[int | None, "Number of comments on the forum post"]
|
||||||
|
|
||||||
|
|
||||||
class ForumPost(SourceItem):
|
class ForumPost(SourceItem):
|
||||||
__tablename__ = "forum_post"
|
__tablename__ = "forum_post"
|
||||||
|
|
||||||
@ -479,20 +556,20 @@ class ForumPost(SourceItem):
|
|||||||
Index("forum_post_title_idx", "title"),
|
Index("forum_post_title_idx", "title"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> ForumPostPayload:
|
||||||
return {
|
return ForumPostPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"url": self.url,
|
url=cast(str, self.url),
|
||||||
"title": self.title,
|
title=cast(str, self.title),
|
||||||
"description": self.description,
|
description=cast(str | None, self.description),
|
||||||
"authors": self.authors,
|
authors=cast(list[str] | None, self.authors),
|
||||||
"published_at": self.published_at,
|
published=(self.published_at and self.published_at.isoformat() or None), # type: ignore
|
||||||
"slug": self.slug,
|
slug=cast(str | None, self.slug),
|
||||||
"karma": self.karma,
|
karma=cast(int | None, self.karma),
|
||||||
"votes": self.votes,
|
votes=cast(int | None, self.votes),
|
||||||
"score": self.score,
|
score=cast(int | None, self.score),
|
||||||
"comments": self.comments,
|
comments=cast(int | None, self.comments),
|
||||||
}
|
)
|
||||||
|
|
||||||
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
def _chunk_contents(self) -> Sequence[extract.DataChunk]:
|
||||||
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
return chunk_mixed(cast(str, self.content), cast(list[str], self.images))
|
||||||
@ -545,6 +622,12 @@ class GithubItem(SourceItem):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NotePayload(SourceItemPayload):
|
||||||
|
note_type: Annotated[str | None, "Category of the note"]
|
||||||
|
subject: Annotated[str | None, "What the note is about"]
|
||||||
|
confidence: Annotated[dict[str, float], "Confidence scores for the note"]
|
||||||
|
|
||||||
|
|
||||||
class Note(SourceItem):
|
class Note(SourceItem):
|
||||||
"""A quick note of something of interest."""
|
"""A quick note of something of interest."""
|
||||||
|
|
||||||
@ -565,13 +648,13 @@ class Note(SourceItem):
|
|||||||
Index("note_subject_idx", "subject"),
|
Index("note_subject_idx", "subject"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> NotePayload:
|
||||||
return {
|
return NotePayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"note_type": self.note_type,
|
note_type=cast(str | None, self.note_type),
|
||||||
"subject": self.subject,
|
subject=cast(str | None, self.subject),
|
||||||
"confidence": self.confidence_dict,
|
confidence=self.confidence_dict,
|
||||||
}
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_contents(self) -> dict:
|
def display_contents(self) -> dict:
|
||||||
@ -602,6 +685,19 @@ class Note(SourceItem):
|
|||||||
self.as_text(cast(str, self.content), cast(str | None, self.subject))
|
self.as_text(cast(str, self.content), cast(str | None, self.subject))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collections(cls) -> list[str]:
|
||||||
|
return ["text"] # Notes go to the text collection
|
||||||
|
|
||||||
|
|
||||||
|
class AgentObservationPayload(SourceItemPayload):
|
||||||
|
session_id: Annotated[str | None, "Session ID for the observation"]
|
||||||
|
observation_type: Annotated[str, "Type of observation"]
|
||||||
|
subject: Annotated[str, "What/who the observation is about"]
|
||||||
|
confidence: Annotated[dict[str, float], "Confidence scores for the observation"]
|
||||||
|
evidence: Annotated[dict | None, "Supporting context, quotes, etc."]
|
||||||
|
agent_model: Annotated[str, "Which AI model made this observation"]
|
||||||
|
|
||||||
|
|
||||||
class AgentObservation(SourceItem):
|
class AgentObservation(SourceItem):
|
||||||
"""
|
"""
|
||||||
@ -652,18 +748,16 @@ class AgentObservation(SourceItem):
|
|||||||
kwargs["modality"] = "observation"
|
kwargs["modality"] = "observation"
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def as_payload(self) -> dict:
|
def as_payload(self) -> AgentObservationPayload:
|
||||||
payload = {
|
return AgentObservationPayload(
|
||||||
**super().as_payload(),
|
**super().as_payload(),
|
||||||
"observation_type": self.observation_type,
|
observation_type=cast(str, self.observation_type),
|
||||||
"subject": self.subject,
|
subject=cast(str, self.subject),
|
||||||
"confidence": self.confidence_dict,
|
confidence=self.confidence_dict,
|
||||||
"evidence": self.evidence,
|
evidence=cast(dict | None, self.evidence),
|
||||||
"agent_model": self.agent_model,
|
agent_model=cast(str, self.agent_model),
|
||||||
}
|
session_id=cast(str | None, self.session_id) and str(self.session_id),
|
||||||
if self.session_id is not None:
|
)
|
||||||
payload["session_id"] = str(self.session_id)
|
|
||||||
return payload
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_contradictions(self):
|
def all_contradictions(self):
|
||||||
@ -759,3 +853,7 @@ class AgentObservation(SourceItem):
|
|||||||
# ))
|
# ))
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collections(cls) -> list[str]:
|
||||||
|
return ["semantic", "temporal"]
|
||||||
|
@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
|
def doc_to_images(
|
||||||
|
content: bytes | str | pathlib.Path, modality: str = "doc"
|
||||||
|
) -> list[DataChunk]:
|
||||||
with as_file(content) as file_path:
|
with as_file(content) as file_path:
|
||||||
with pymupdf.open(file_path) as pdf:
|
with pymupdf.open(file_path) as pdf:
|
||||||
return [
|
return [
|
||||||
@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
|
|||||||
"height": page.rect.height,
|
"height": page.rect.height,
|
||||||
},
|
},
|
||||||
mime_type="image/jpeg",
|
mime_type="image/jpeg",
|
||||||
|
modality=modality,
|
||||||
)
|
)
|
||||||
for page in pdf.pages()
|
for page in pdf.pages()
|
||||||
]
|
]
|
||||||
@ -122,6 +125,7 @@ def extract_text(
|
|||||||
content: bytes | str | pathlib.Path,
|
content: bytes | str | pathlib.Path,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
metadata: dict[str, Any] = {},
|
metadata: dict[str, Any] = {},
|
||||||
|
modality: str = "text",
|
||||||
) -> list[DataChunk]:
|
) -> list[DataChunk]:
|
||||||
if isinstance(content, pathlib.Path):
|
if isinstance(content, pathlib.Path):
|
||||||
content = content.read_text()
|
content = content.read_text()
|
||||||
@ -130,7 +134,7 @@ def extract_text(
|
|||||||
|
|
||||||
content = cast(str, content)
|
content = cast(str, content)
|
||||||
chunks = [
|
chunks = [
|
||||||
DataChunk(data=[c], modality="text", metadata=metadata)
|
DataChunk(data=[c], modality=modality, metadata=metadata)
|
||||||
for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
|
for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
|
||||||
]
|
]
|
||||||
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:
|
||||||
@ -139,7 +143,7 @@ def extract_text(
|
|||||||
DataChunk(
|
DataChunk(
|
||||||
data=[summary],
|
data=[summary],
|
||||||
metadata=merge_metadata(metadata, {"tags": tags}),
|
metadata=merge_metadata(metadata, {"tags": tags}),
|
||||||
modality="text",
|
modality=modality,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return chunks
|
return chunks
|
||||||
@ -158,9 +162,7 @@ def extract_data_chunks(
|
|||||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
"application/msword",
|
"application/msword",
|
||||||
]:
|
]:
|
||||||
logger.info(f"Extracting content from {content}")
|
|
||||||
chunks = extract_docx(content)
|
chunks = extract_docx(content)
|
||||||
logger.info(f"Extracted {len(chunks)} pages from {content}")
|
|
||||||
elif mime_type.startswith("text/"):
|
elif mime_type.startswith("text/"):
|
||||||
chunks = extract_text(content, chunk_size)
|
chunks = extract_text(content, chunk_size)
|
||||||
elif mime_type.startswith("image/"):
|
elif mime_type.startswith("image/"):
|
||||||
|
@ -224,6 +224,15 @@ def get_collection_info(
|
|||||||
return info.model_dump()
|
return info.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
|
||||||
|
"""Get the size of each collection."""
|
||||||
|
collections = [i.name for i in client.get_collections().collections]
|
||||||
|
return {
|
||||||
|
collection_name: client.count(collection_name).count # type: ignore
|
||||||
|
for collection_name in collections
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def batch_ids(
|
def batch_ids(
|
||||||
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
|
client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
|
||||||
) -> Generator[list[str], None, None]:
|
) -> Generator[list[str], None, None]:
|
||||||
|
@ -6,6 +6,8 @@ load_dotenv()
|
|||||||
|
|
||||||
|
|
||||||
def boolean_env(key: str, default: bool = False) -> bool:
|
def boolean_env(key: str, default: bool = False) -> bool:
|
||||||
|
if key not in os.environ:
|
||||||
|
return default
|
||||||
return os.getenv(key, "0").lower() in ("1", "true", "yes")
|
return os.getenv(key, "0").lower() in ("1", "true", "yes")
|
||||||
|
|
||||||
|
|
||||||
@ -130,6 +132,10 @@ else:
|
|||||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")
|
||||||
|
|
||||||
|
# Search settings
|
||||||
|
ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
|
||||||
|
ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
|
||||||
|
|
||||||
# API settings
|
# API settings
|
||||||
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
|
||||||
HTTPS = boolean_env("HTTPS", False)
|
HTTPS = boolean_env("HTTPS", False)
|
||||||
|
174
tests/memory/api/search/test_search_embeddings.py
Normal file
174
tests/memory/api/search/test_search_embeddings.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
import pytest
|
||||||
|
from memory.api.search.embeddings import merge_range_filter, merge_filters
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_range_filter_new_filter():
|
||||||
|
"""Test adding new range filters"""
|
||||||
|
filters = []
|
||||||
|
result = merge_range_filter(filters, "min_size", 100)
|
||||||
|
assert result == [{"key": "size", "range": {"gte": 100}}]
|
||||||
|
|
||||||
|
filters = []
|
||||||
|
result = merge_range_filter(filters, "max_size", 1000)
|
||||||
|
assert result == [{"key": "size", "range": {"lte": 1000}}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_range_filter_existing_field():
|
||||||
|
"""Test adding to existing field"""
|
||||||
|
filters = [{"key": "size", "range": {"lte": 1000}}]
|
||||||
|
result = merge_range_filter(filters, "min_size", 100)
|
||||||
|
assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_range_filter_override_existing():
|
||||||
|
"""Test overriding existing values"""
|
||||||
|
filters = [{"key": "size", "range": {"gte": 100}}]
|
||||||
|
result = merge_range_filter(filters, "min_size", 200)
|
||||||
|
assert result == [{"key": "size", "range": {"gte": 200}}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_range_filter_with_other_filters():
|
||||||
|
"""Test adding range filter alongside other filters"""
|
||||||
|
filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
|
||||||
|
result = merge_range_filter(filters, "min_size", 100)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
{"key": "tags", "match": {"any": ["tag1"]}},
|
||||||
|
{"key": "size", "range": {"gte": 100}},
|
||||||
|
]
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"key,expected_direction,expected_field",
|
||||||
|
[
|
||||||
|
("min_sent_at", "min", "sent_at"),
|
||||||
|
("max_sent_at", "max", "sent_at"),
|
||||||
|
("min_published", "min", "published"),
|
||||||
|
("max_published", "max", "published"),
|
||||||
|
("min_size", "min", "size"),
|
||||||
|
("max_size", "max", "size"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
|
||||||
|
"""Test that field names are correctly extracted from keys"""
|
||||||
|
filters = []
|
||||||
|
result = merge_range_filter(filters, key, 100)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["key"] == expected_field
|
||||||
|
range_key = "gte" if expected_direction == "min" else "lte"
|
||||||
|
assert result[0]["range"][range_key] == 100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"filter_key,filter_value",
|
||||||
|
[
|
||||||
|
("tags", ["tag1", "tag2"]),
|
||||||
|
("recipients", ["user1", "user2"]),
|
||||||
|
("observation_types", ["belief", "preference"]),
|
||||||
|
("authors", ["author1"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_merge_filters_list_filters(filter_key, filter_value):
|
||||||
|
"""Test list filters that use match any"""
|
||||||
|
filters = []
|
||||||
|
result = merge_filters(filters, filter_key, filter_value)
|
||||||
|
expected = [{"key": filter_key, "match": {"any": filter_value}}]
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_min_confidences():
|
||||||
|
"""Test min_confidences filter creates multiple range conditions"""
|
||||||
|
filters = []
|
||||||
|
confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9}
|
||||||
|
result = merge_filters(filters, "min_confidences", confidences)
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
{"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
|
||||||
|
{"key": "confidence.source_reliability", "range": {"gte": 0.9}},
|
||||||
|
]
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_source_ids():
|
||||||
|
"""Test source_ids filter maps to id field"""
|
||||||
|
filters = []
|
||||||
|
result = merge_filters(filters, "source_ids", ["id1", "id2"])
|
||||||
|
expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}]
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_range_delegation():
|
||||||
|
"""Test range filters are properly delegated to merge_range_filter"""
|
||||||
|
filters = []
|
||||||
|
result = merge_filters(filters, "min_size", 100)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "range" in result[0]
|
||||||
|
assert result[0]["range"]["gte"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_combined_range():
|
||||||
|
"""Test that min/max range pairs merge into single filter"""
|
||||||
|
filters = []
|
||||||
|
filters = merge_filters(filters, "min_size", 100)
|
||||||
|
filters = merge_filters(filters, "max_size", 1000)
|
||||||
|
|
||||||
|
size_filters = [f for f in filters if f["key"] == "size"]
|
||||||
|
assert len(size_filters) == 1
|
||||||
|
assert size_filters[0]["range"]["gte"] == 100
|
||||||
|
assert size_filters[0]["range"]["lte"] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_preserves_existing():
|
||||||
|
"""Test that existing filters are preserved when adding new ones"""
|
||||||
|
existing_filters = [{"key": "existing", "match": "value"}]
|
||||||
|
result = merge_filters(existing_filters, "tags", ["new_tag"])
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert {"key": "existing", "match": "value"} in result
|
||||||
|
assert {"key": "tags", "match": {"any": ["new_tag"]}} in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_realistic_combination():
|
||||||
|
"""Test a realistic filter combination for knowledge base search"""
|
||||||
|
filters = []
|
||||||
|
|
||||||
|
# Add typical knowledge base filters
|
||||||
|
filters = merge_filters(filters, "tags", ["important", "work"])
|
||||||
|
filters = merge_filters(filters, "min_published", "2023-01-01")
|
||||||
|
filters = merge_filters(filters, "max_size", 1000000) # 1MB max
|
||||||
|
filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})
|
||||||
|
|
||||||
|
assert len(filters) == 4
|
||||||
|
|
||||||
|
# Check each filter type
|
||||||
|
tag_filter = next(f for f in filters if f["key"] == "tags")
|
||||||
|
assert tag_filter["match"]["any"] == ["important", "work"]
|
||||||
|
|
||||||
|
published_filter = next(f for f in filters if f["key"] == "published")
|
||||||
|
assert published_filter["range"]["gte"] == "2023-01-01"
|
||||||
|
|
||||||
|
size_filter = next(f for f in filters if f["key"] == "size")
|
||||||
|
assert size_filter["range"]["lte"] == 1000000
|
||||||
|
|
||||||
|
confidence_filter = next(
|
||||||
|
f for f in filters if f["key"] == "confidence.observation_accuracy"
|
||||||
|
)
|
||||||
|
assert confidence_filter["range"]["gte"] == 0.8
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_unknown_key():
|
||||||
|
"""Test fallback behavior for unknown filter keys"""
|
||||||
|
filters = []
|
||||||
|
result = merge_filters(filters, "unknown_field", "unknown_value")
|
||||||
|
expected = [{"key": "unknown_field", "match": "unknown_value"}]
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters_empty_min_confidences():
|
||||||
|
"""Test min_confidences with empty dict does nothing"""
|
||||||
|
filters = []
|
||||||
|
result = merge_filters(filters, "min_confidences", {})
|
||||||
|
assert result == []
|
Loading…
x
Reference in New Issue
Block a user