Mirror of https://github.com/mruwnik/memory.git, synced 2025-07-29 14:16:09 +02:00

Compare commits: 4 commits (780e27ba04 ... 5e836337e2)

Commits: 5e836337e2, 55809f3980, 0e574542d5, 3e4e5872d1
@@ -49,6 +49,7 @@ x-worker-base: &worker-base
       QDRANT_URL: http://qdrant:6333
       OPENAI_API_KEY_FILE: /run/secrets/openai_key
       ANTHROPIC_API_KEY_FILE: /run/secrets/anthropic_key
+      VOYAGE_API_KEY: ${VOYAGE_API_KEY}
       secrets: [ postgres_password, openai_key, anthropic_key, ssh_private_key, ssh_public_key, ssh_known_hosts ]
       read_only: true
       tmpfs:

@@ -150,6 +151,7 @@ services:
       SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
       VITE_SERVER_URL: "${SERVER_URL:-http://localhost:8000}"
       STATIC_DIR: "/app/static"
+      VOYAGE_API_KEY: ${VOYAGE_API_KEY}
       secrets: [postgres_password]
       volumes:
         - ./memory_files:/app/memory_files:rw
@@ -285,6 +285,7 @@ body {
   display: flex;
   gap: 1rem;
   align-items: center;
+  margin-bottom: 1.5rem;
 }

 .search-input {

@@ -323,6 +324,34 @@ body {
   cursor: not-allowed;
 }

+.search-options {
+  border-top: 1px solid #e2e8f0;
+  padding-top: 1.5rem;
+  display: grid;
+  gap: 1.5rem;
+}
+
+.search-option {
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.search-option label {
+  font-weight: 500;
+  color: #4a5568;
+  font-size: 0.9rem;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+}
+
+.search-option input[type="checkbox"] {
+  margin-right: 0.5rem;
+  transform: scale(1.1);
+  accent-color: #667eea;
+}
+
 /* Search Results */
 .search-results {
   margin-top: 2rem;

@@ -429,6 +458,11 @@ body {
   transition: background-color 0.2s;
 }

+.tag.selected {
+  background: #667eea;
+  color: white;
+}
+
 .tag:hover {
   background: #e2e8f0;
 }

@@ -495,6 +529,14 @@ body {
   padding: 1.5rem;
 }

+.search-options {
+  gap: 1rem;
+}
+
+.limit-input {
+  width: 80px;
+}
+
 .search-result-card {
   padding: 1rem;
 }

@@ -689,4 +731,194 @@ body {

 .markdown-preview a:hover {
   color: #2b6cb0;
 }
+
+/* Dynamic filters styles */
+.modality-filters {
+  display: flex;
+  flex-direction: column;
+  gap: 1rem;
+  margin-top: 0.5rem;
+}
+
+.modality-filter-group {
+  border: 1px solid #e2e8f0;
+  border-radius: 8px;
+  background: white;
+}
+
+.modality-filter-title {
+  padding: 0.75rem 1rem;
+  font-weight: 600;
+  color: #4a5568;
+  background: #f8fafc;
+  border-radius: 7px 7px 0 0;
+  cursor: pointer;
+  user-select: none;
+  transition: background-color 0.2s;
+}
+
+.modality-filter-title:hover {
+  background: #edf2f7;
+}
+
+.modality-filter-group[open] .modality-filter-title {
+  border-bottom: 1px solid #e2e8f0;
+  border-radius: 7px 7px 0 0;
+}
+
+.filters-grid {
+  display: grid;
+  grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+  gap: 1rem;
+  padding: 1rem;
+}
+
+.filter-field {
+  display: flex;
+  flex-direction: column;
+  gap: 0.25rem;
+}
+
+.filter-label {
+  font-size: 0.875rem;
+  font-weight: 500;
+  color: #4a5568;
+  margin-bottom: 0.25rem;
+}
+
+.filter-input {
+  padding: 0.5rem;
+  border: 1px solid #e2e8f0;
+  border-radius: 4px;
+  font-size: 0.875rem;
+  background: white;
+  transition: border-color 0.2s;
+}
+
+.filter-input:focus {
+  outline: none;
+  border-color: #667eea;
+  box-shadow: 0 0 0 1px #667eea;
+}
+
+/* Selectable tags controls */
+.selectable-tags-details {
+  border: 1px solid #e2e8f0;
+  border-radius: 8px;
+  background: white;
+}
+
+.selectable-tags-summary {
+  padding: 0.75rem 1rem;
+  font-weight: 600;
+  color: #4a5568;
+  cursor: pointer;
+  user-select: none;
+  transition: background-color 0.2s;
+  border-radius: 7px;
+}
+
+.selectable-tags-summary:hover {
+  background: #f8fafc;
+}
+
+.selectable-tags-details[open] .selectable-tags-summary {
+  border-bottom: 1px solid #e2e8f0;
+  border-radius: 7px 7px 0 0;
+  background: #f8fafc;
+}
+
+.tag-controls {
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  padding: 0.75rem 1rem;
+  border-bottom: 1px solid #e2e8f0;
+  background: #fafbfc;
+}
+
+.tag-control-btn {
+  background: #f7fafc;
+  border: 1px solid #e2e8f0;
+  border-radius: 4px;
+  padding: 0.25rem 0.5rem;
+  font-size: 0.75rem;
+  color: #4a5568;
+  cursor: pointer;
+  transition: all 0.2s;
+}
+
+.tag-control-btn:hover:not(:disabled) {
+  background: #edf2f7;
+  border-color: #cbd5e0;
+}
+
+.tag-control-btn:disabled {
+  opacity: 0.5;
+  cursor: not-allowed;
+}
+
+.tag-count {
+  font-size: 0.75rem;
+  color: #718096;
+  font-weight: 500;
+}
+
+/* Tag search controls */
+.tag-search-controls {
+  display: flex;
+  align-items: center;
+  gap: 1rem;
+  padding: 0.75rem 1rem;
+  background: #f8fafc;
+  border-bottom: 1px solid #e2e8f0;
+}
+
+.tag-search-input {
+  flex: 1;
+  padding: 0.375rem 0.5rem;
+  border: 1px solid #e2e8f0;
+  border-radius: 4px;
+  font-size: 0.875rem;
+  background: white;
+  transition: border-color 0.2s;
+}
+
+.tag-search-input:focus {
+  outline: none;
+  border-color: #667eea;
+  box-shadow: 0 0 0 1px #667eea;
+}
+
+.filtered-count {
+  font-size: 0.75rem;
+  color: #718096;
+  font-weight: 500;
+  white-space: nowrap;
+}
+
+.tags-display-area {
+  padding: 1rem;
+}
+
+@media (max-width: 768px) {
+  .filters-grid {
+    grid-template-columns: 1fr;
+    gap: 0.5rem;
+    padding: 0.75rem;
+  }
+
+  .modality-filter-title {
+    padding: 0.5rem 0.75rem;
+    font-size: 0.9rem;
+  }
+
+  .tag-search-controls {
+    flex-direction: column;
+    align-items: stretch;
+    gap: 0.5rem;
+  }
+}
@@ -2,9 +2,9 @@ import { useEffect } from 'react'
 import { BrowserRouter as Router, Routes, Route, Navigate, useNavigate, useLocation } from 'react-router-dom'
 import './App.css'

-import { useAuth } from './hooks/useAuth'
-import { useOAuth } from './hooks/useOAuth'
-import { Loading, LoginPrompt, AuthError, Dashboard, Search } from './components'
+import { useAuth } from '@/hooks/useAuth'
+import { useOAuth } from '@/hooks/useOAuth'
+import { Loading, LoginPrompt, AuthError, Dashboard, Search } from '@/components'

 // AuthWrapper handles redirects based on auth state
 const AuthWrapper = () => {

@@ -31,7 +31,10 @@ const AuthWrapper = () => {
   // Handle redirects based on auth state changes
   useEffect(() => {
     if (!isLoading) {
-      if (isAuthenticated) {
+      if (location.pathname === '/ui/logout') {
+        logout()
+        navigate('/ui/login', { replace: true })
+      } else if (isAuthenticated) {
         // If authenticated and on login page, redirect to dashboard
         if (location.pathname === '/ui/login' || location.pathname === '/ui') {
           navigate('/ui/dashboard', { replace: true })
@@ -1,5 +1,5 @@
 import { Link } from 'react-router-dom'
-import { useMCP } from '../hooks/useMCP'
+import { useMCP } from '@/hooks/useMCP'

 const Dashboard = ({ onLogout }) => {
   const { listNotes } = useMCP()
@@ -1,6 +1,4 @@
-import React from 'react'
-
-const Loading = ({ message = "Loading..." }) => {
+const Loading = ({ message = "Loading..." }: { message?: string }) => {
   return (
     <div className="loading">
       <h2>{message}</h2>
@@ -1,6 +1,4 @@
-import React from 'react'
-
-const AuthError = ({ error, onRetry }) => {
+const AuthError = ({ error, onRetry }: { error: string, onRetry: () => void }) => {
   return (
     <div className="error">
       <h2>Authentication Error</h2>
@@ -1,6 +1,4 @@
-import React from 'react'
-
-const LoginPrompt = ({ onLogin }) => {
+const LoginPrompt = ({ onLogin }: { onLogin: () => void }) => {
   return (
     <div className="login-prompt">
       <h1>Memory App</h1>
@@ -1,5 +1,5 @@
 export { default as Loading } from './Loading'
 export { default as Dashboard } from './Dashboard'
-export { default as Search } from './Search'
+export { default as Search } from './search'
 export { default as LoginPrompt } from './auth/LoginPrompt'
 export { default as AuthError } from './auth/AuthError'
frontend/src/components/search/DynamicFilters.tsx (new file, 155 lines)
@@ -0,0 +1,155 @@
import { FilterInput } from './FilterInput'
import { CollectionMetadata, SchemaArg } from '@/types/mcp'

// Pure helper functions for schema processing
const formatFieldLabel = (field: string): string =>
    field.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase())

const shouldSkipField = (fieldName: string): boolean =>
    ['tags'].includes(fieldName)

const isCommonField = (fieldName: string): boolean =>
    ['size', 'filename', 'content_type', 'mail_message_id', 'sent_at', 'created_at', 'source_id'].includes(fieldName)

const createMinMaxFields = (fieldName: string, fieldConfig: any): [string, any][] => [
    [`min_${fieldName}`, { ...fieldConfig, description: `Min ${fieldConfig.description}` }],
    [`max_${fieldName}`, { ...fieldConfig, description: `Max ${fieldConfig.description}` }]
]

const createSizeFields = (): [string, any][] => [
    ['min_size', { type: 'int', description: 'Minimum size in bytes' }],
    ['max_size', { type: 'int', description: 'Maximum size in bytes' }]
]

const extractSchemaFields = (schema: Record<string, SchemaArg>, includeCommon = true): [string, SchemaArg][] =>
    Object.entries(schema)
        .filter(([fieldName]) => !shouldSkipField(fieldName))
        .filter(([fieldName]) => includeCommon || !isCommonField(fieldName))
        .flatMap(([fieldName, fieldConfig]) =>
            ['sent_at', 'published', 'created_at'].includes(fieldName)
                ? createMinMaxFields(fieldName, fieldConfig)
                : [[fieldName, fieldConfig] as [string, SchemaArg]]
        )

const getCommonFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): [string, SchemaArg][] => {
    const commonFieldsMap = new Map<string, SchemaArg>()

    // Manually add created_at fields even if not in schema
    createMinMaxFields('created_at', { type: 'datetime', description: 'Creation date' }).forEach(([field, config]) => {
        commonFieldsMap.set(field, config)
    })

    selectedModalities.forEach(modality => {
        const schema = schemas[modality].schema
        if (!schema) return

        Object.entries(schema).forEach(([fieldName, fieldConfig]) => {
            if (isCommonField(fieldName)) {
                if (['sent_at', 'created_at'].includes(fieldName)) {
                    createMinMaxFields(fieldName, fieldConfig).forEach(([field, config]) => {
                        commonFieldsMap.set(field, config)
                    })
                } else if (fieldName === 'size') {
                    createSizeFields().forEach(([field, config]) => {
                        commonFieldsMap.set(field, config)
                    })
                } else {
                    commonFieldsMap.set(fieldName, fieldConfig)
                }
            }
        })
    })

    return Array.from(commonFieldsMap.entries())
}

const getModalityFields = (schemas: Record<string, CollectionMetadata>, selectedModalities: string[]): Record<string, [string, SchemaArg][]> => {
    return selectedModalities.reduce((acc, modality) => {
        const schema = schemas[modality].schema
        if (!schema) return acc

        const schemaFields = extractSchemaFields(schema, false) // Exclude common fields

        if (schemaFields.length > 0) {
            acc[modality] = schemaFields
        }

        return acc
    }, {} as Record<string, [string, SchemaArg][]>)
}
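Taken together, these helpers flatten a collection schema into the list of filter fields the form renders: tags is skipped, date-like fields are split into min/max pairs, and size gets explicit byte bounds. A sketch of the expected output for a made-up mail schema (the field values are illustrative, not the real collection schema):

const mailSchema: Record<string, SchemaArg> = {
    subject: { type: 'str', description: 'The subject of the email.' },
    sent_at: { type: 'datetime', description: 'send date' },
    tags: { type: 'list[str]', description: 'topic tags' },
}

// 'tags' is dropped and 'sent_at' expands into a min/max pair, yielding:
// [['subject', ...], ['min_sent_at', { description: 'Min send date', ... }], ['max_sent_at', { description: 'Max send date', ... }]]
extractSchemaFields(mailSchema)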
interface DynamicFiltersProps {
    schemas: Record<string, CollectionMetadata>
    selectedModalities: string[]
    filters: Record<string, SchemaArg>
    onFilterChange: (field: string, value: SchemaArg) => void
}

export const DynamicFilters = ({
    schemas,
    selectedModalities,
    filters,
    onFilterChange
}: DynamicFiltersProps) => {
    const commonFields = getCommonFields(schemas, selectedModalities)
    const modalityFields = getModalityFields(schemas, selectedModalities)

    if (commonFields.length === 0 && Object.keys(modalityFields).length === 0) {
        return null
    }

    return (
        <div className="search-option">
            <label>Filters:</label>
            <div className="modality-filters">
                {/* Common/Document Properties Section */}
                {commonFields.length > 0 && (
                    <details className="modality-filter-group" open>
                        <summary className="modality-filter-title">
                            Document Properties
                        </summary>
                        <div className="filters-grid">
                            {commonFields.map(([field, fieldConfig]: [string, SchemaArg]) => (
                                <div key={field} className="filter-field">
                                    <label className="filter-label">
                                        {formatFieldLabel(field)}:
                                    </label>
                                    <FilterInput
                                        field={field}
                                        fieldConfig={fieldConfig}
                                        value={filters[field]}
                                        onChange={onFilterChange}
                                    />
                                </div>
                            ))}
                        </div>
                    </details>
                )}

                {/* Modality-specific sections */}
                {Object.entries(modalityFields).map(([modality, fields]) => (
                    <details key={modality} className="modality-filter-group">
                        <summary className="modality-filter-title">
                            {formatFieldLabel(modality)} Specific
                        </summary>
                        <div className="filters-grid">
                            {fields.map(([field, fieldConfig]: [string, SchemaArg]) => (
                                <div key={field} className="filter-field">
                                    <label className="filter-label">
                                        {formatFieldLabel(field)}:
                                    </label>
                                    <FilterInput
                                        field={field}
                                        fieldConfig={fieldConfig}
                                        value={filters[field]}
                                        onChange={onFilterChange}
                                    />
                                </div>
                            ))}
                        </div>
                    </details>
                ))}
            </div>
        </div>
    )
}
frontend/src/components/search/FilterInput.tsx (new file, 53 lines)
@@ -0,0 +1,53 @@
// Pure helper functions
const isDateField = (field: string): boolean =>
    field.includes('sent_at') || field.includes('published') || field.includes('created_at')

const isNumberField = (field: string, fieldConfig: any): boolean =>
    fieldConfig.type?.includes('int') || field.includes('size')

const parseNumberValue = (value: string): number | null =>
    value ? parseInt(value) : null

interface FilterInputProps {
    field: string
    fieldConfig: any
    value: any
    onChange: (field: string, value: any) => void
}

export const FilterInput = ({ field, fieldConfig, value, onChange }: FilterInputProps) => {
    const inputProps = {
        value: value || '',
        className: "filter-input"
    }

    if (isNumberField(field, fieldConfig)) {
        return (
            <input
                {...inputProps}
                type="number"
                onChange={(e) => onChange(field, parseNumberValue(e.target.value))}
                placeholder={fieldConfig.description}
            />
        )
    }

    if (isDateField(field)) {
        return (
            <input
                {...inputProps}
                type="date"
                onChange={(e) => onChange(field, e.target.value || null)}
            />
        )
    }

    return (
        <input
            {...inputProps}
            type="text"
            onChange={(e) => onChange(field, e.target.value || null)}
            placeholder={fieldConfig.description}
        />
    )
}
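The widget type is inferred from the field name and schema type rather than passed explicitly. Hypothetical invocations, assuming an onFilterChange handler like the one DynamicFilters forwards:

// the config type contains 'int', so this renders <input type="number">
<FilterInput field="min_size" fieldConfig={{ type: 'int', description: 'Minimum size in bytes' }} value={null} onChange={onFilterChange} />

// the field name contains 'sent_at', so this renders <input type="date">
<FilterInput field="min_sent_at" fieldConfig={{ type: 'datetime', description: 'Min send date' }} value={null} onChange={onFilterChange} />

// anything else falls back to a plain text input
<FilterInput field="subject" fieldConfig={{ type: 'str', description: 'The subject of the email.' }} value={null} onChange={onFilterChange} />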
frontend/src/components/search/Search.tsx (new file, 69 lines)
@@ -0,0 +1,69 @@
import { useState } from 'react'
import { useNavigate } from 'react-router-dom'
import { useMCP } from '@/hooks/useMCP'
import Loading from '@/components/Loading'
import { SearchResult } from './results'
import SearchForm, { SearchParams } from './SearchForm'

const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
    if (isLoading) {
        return <Loading message="Searching..." />
    }
    return (
        <div className="search-results">
            {results.length > 0 && (
                <div className="results-count">
                    Found {results.length} result{results.length !== 1 ? 's' : ''}
                </div>
            )}

            {results.map((result, index) => <SearchResult key={index} result={result} />)}

            {results.length === 0 && (
                <div className="no-results">
                    No results found
                </div>
            )}
        </div>
    )
}

const Search = () => {
    const navigate = useNavigate()
    const [results, setResults] = useState([])
    const [isLoading, setIsLoading] = useState(false)
    const { searchKnowledgeBase } = useMCP()

    const handleSearch = async (params: SearchParams) => {
        if (!params.query.trim()) return

        setIsLoading(true)
        try {
            console.log(params)
            const searchResults = await searchKnowledgeBase(params.query, params.previews, params.limit, params.filters, params.modalities)
            setResults(searchResults || [])
        } catch (error) {
            console.error('Search error:', error)
            setResults([])
        } finally {
            setIsLoading(false)
        }
    }

    return (
        <div className="search-view">
            <header className="search-header">
                <button onClick={() => navigate('/ui/dashboard')} className="back-btn">
                    ← Back to Dashboard
                </button>
                <h2>🔍 Search Knowledge Base</h2>
            </header>

            <SearchForm isLoading={isLoading} onSearch={handleSearch} />
            <SearchResults results={results} isLoading={isLoading} />
        </div>
    )
}

export default Search
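For reference, the params object handed to handleSearch has this shape; the values are illustrative:

const params: SearchParams = {
    query: 'voyage embeddings',
    previews: true,
    modalities: ['mail', 'blog'],
    filters: { tags: ['ai'], min_created_at: '2025-01-01' },
    limit: 10,
}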
frontend/src/components/search/SearchForm.tsx (new file, 151 lines)
@@ -0,0 +1,151 @@
import { useMCP } from '@/hooks/useMCP'
import { useEffect, useState } from 'react'
import { DynamicFilters } from './DynamicFilters'
import { SelectableTags } from './SelectableTags'
import { CollectionMetadata } from '@/types/mcp'

type Filter = {
    tags?: string[]
    source_ids?: string[]
    [key: string]: any
}

export interface SearchParams {
    query: string
    previews: boolean
    modalities: string[]
    filters: Filter
    limit: number
}

interface SearchFormProps {
    isLoading: boolean
    onSearch: (params: SearchParams) => void
}

// Pure helper functions for SearchForm
const createFlags = (items: string[], defaultValue = false): Record<string, boolean> =>
    items.reduce((acc, item) => ({ ...acc, [item]: defaultValue }), {})

const getSelectedItems = (items: Record<string, boolean>): string[] =>
    Object.entries(items).filter(([_, selected]) => selected).map(([key]) => key)

const cleanFilters = (filters: Record<string, any>): Record<string, any> =>
    Object.entries(filters)
        .filter(([_, value]) => value !== null && value !== '' && value !== undefined)
        .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {})
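What the three helpers evaluate to, with invented inputs:

createFlags(['mail', 'blog'], true)            // { mail: true, blog: true }
getSelectedItems({ mail: true, blog: false })  // ['mail']
cleanFilters({ min_size: 100, subject: '', author: null })  // { min_size: 100 } (empty strings, nulls and undefined are dropped)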
export const SearchForm = ({ isLoading, onSearch }: SearchFormProps) => {
    const [query, setQuery] = useState('')
    const [previews, setPreviews] = useState(false)
    const [modalities, setModalities] = useState<Record<string, boolean>>({})
    const [schemas, setSchemas] = useState<Record<string, CollectionMetadata>>({})
    const [tags, setTags] = useState<Record<string, boolean>>({})
    const [dynamicFilters, setDynamicFilters] = useState<Record<string, any>>({})
    const [limit, setLimit] = useState(10)
    const { getMetadataSchemas, getTags } = useMCP()

    useEffect(() => {
        const setupFilters = async () => {
            const [schemas, tags] = await Promise.all([
                getMetadataSchemas(),
                getTags()
            ])
            setSchemas(schemas)
            setModalities(createFlags(Object.keys(schemas), true))
            setTags(createFlags(tags))
        }
        setupFilters()
    }, [getMetadataSchemas, getTags])

    const handleFilterChange = (field: string, value: any) =>
        setDynamicFilters(prev => ({ ...prev, [field]: value }))

    const handleSubmit = (e: React.FormEvent) => {
        e.preventDefault()

        onSearch({
            query,
            previews,
            modalities: getSelectedItems(modalities),
            filters: {
                tags: getSelectedItems(tags),
                ...cleanFilters(dynamicFilters)
            },
            limit
        })
    }

    return (
        <form onSubmit={handleSubmit} className="search-form">
            <div className="search-input-group">
                <input
                    type="text"
                    value={query}
                    onChange={(e) => setQuery(e.target.value)}
                    placeholder="Search your knowledge base..."
                    className="search-input"
                    required
                />
                <button type="submit" disabled={isLoading} className="search-btn">
                    {isLoading ? 'Searching...' : 'Search'}
                </button>
            </div>

            <div className="search-options">
                <div className="search-option">
                    <label>
                        <input
                            type="checkbox"
                            checked={previews}
                            onChange={(e) => setPreviews(e.target.checked)}
                        />
                        Include content previews
                    </label>
                </div>

                <SelectableTags
                    title="Modalities"
                    className="modality-checkboxes"
                    tags={modalities}
                    onSelect={(tag, selected) => setModalities({ ...modalities, [tag]: selected })}
                    onBatchUpdate={(updates) => setModalities(updates)}
                />

                <SelectableTags
                    title="Tags"
                    className="tags-container"
                    tags={tags}
                    onSelect={(tag, selected) => setTags({ ...tags, [tag]: selected })}
                    onBatchUpdate={(updates) => setTags(updates)}
                    searchable={true}
                />

                <DynamicFilters
                    schemas={schemas}
                    selectedModalities={getSelectedItems(modalities)}
                    filters={dynamicFilters}
                    onFilterChange={handleFilterChange}
                />

                <div className="search-option">
                    <label>
                        Max Results:
                        <input
                            type="number"
                            value={limit}
                            onChange={(e) => setLimit(parseInt(e.target.value) || 10)}
                            min={1}
                            max={100}
                            className="limit-input"
                        />
                    </label>
                </div>
            </div>
        </form>
    )
}

export default SearchForm
frontend/src/components/search/SelectableTags.tsx (new file, 136 lines)
@@ -0,0 +1,136 @@
interface SelectableTagProps {
    tag: string
    selected: boolean
    onSelect: (tag: string, selected: boolean) => void
}

const SelectableTag = ({ tag, selected, onSelect }: SelectableTagProps) => {
    return (
        <span
            className={`tag ${selected ? 'selected' : ''}`}
            onClick={() => onSelect(tag, !selected)}
        >
            {tag}
        </span>
    )
}

import { useState } from 'react'

interface SelectableTagsProps {
    title: string
    className: string
    tags: Record<string, boolean>
    onSelect: (tag: string, selected: boolean) => void
    onBatchUpdate?: (updates: Record<string, boolean>) => void
    searchable?: boolean
}

export const SelectableTags = ({ title, className, tags, onSelect, onBatchUpdate, searchable = false }: SelectableTagsProps) => {
    const [searchTerm, setSearchTerm] = useState('')
    const handleSelectAll = () => {
        if (onBatchUpdate) {
            const updates = Object.keys(tags).reduce((acc, tag) => {
                acc[tag] = true
                return acc
            }, {} as Record<string, boolean>)
            onBatchUpdate(updates)
        } else {
            // Fallback to individual updates (though this won't work well)
            Object.keys(tags).forEach(tag => {
                if (!tags[tag]) {
                    onSelect(tag, true)
                }
            })
        }
    }

    const handleDeselectAll = () => {
        if (onBatchUpdate) {
            const updates = Object.keys(tags).reduce((acc, tag) => {
                acc[tag] = false
                return acc
            }, {} as Record<string, boolean>)
            onBatchUpdate(updates)
        } else {
            // Fallback to individual updates (though this won't work well)
            Object.keys(tags).forEach(tag => {
                if (tags[tag]) {
                    onSelect(tag, false)
                }
            })
        }
    }

    // Filter tags based on search term
    const filteredTags = Object.entries(tags).filter(([tag, selected]) => {
        return !searchTerm || tag.toLowerCase().includes(searchTerm.toLowerCase())
    })

    const selectedCount = Object.values(tags).filter(Boolean).length
    const totalCount = Object.keys(tags).length
    const filteredSelectedCount = filteredTags.filter(([_, selected]) => selected).length
    const filteredTotalCount = filteredTags.length
    const allSelected = selectedCount === totalCount
    const noneSelected = selectedCount === 0

    return (
        <div className="search-option">
            <details className="selectable-tags-details">
                <summary className="selectable-tags-summary">
                    {title} ({selectedCount} selected)
                </summary>

                <div className="tag-controls">
                    <button
                        type="button"
                        className="tag-control-btn"
                        onClick={handleSelectAll}
                        disabled={allSelected}
                    >
                        All
                    </button>
                    <button
                        type="button"
                        className="tag-control-btn"
                        onClick={handleDeselectAll}
                        disabled={noneSelected}
                    >
                        None
                    </button>
                    <span className="tag-count">
                        ({selectedCount}/{totalCount})
                    </span>
                </div>

                {searchable && (
                    <div className="tag-search-controls">
                        <input
                            type="text"
                            placeholder={`Search ${title.toLowerCase()}...`}
                            value={searchTerm}
                            onChange={(e) => setSearchTerm(e.target.value)}
                            className="tag-search-input"
                        />
                        {searchTerm && (
                            <span className="filtered-count">
                                Showing {filteredSelectedCount}/{filteredTotalCount}
                            </span>
                        )}
                    </div>
                )}

                <div className={`${className} tags-display-area`}>
                    {filteredTags.map(([tag, selected]: [string, boolean]) => (
                        <SelectableTag
                            key={tag}
                            tag={tag}
                            selected={selected}
                            onSelect={onSelect}
                        />
                    ))}
                </div>
            </details>
        </div>
    )
}
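A minimal usage sketch: the component is fully controlled, with selection state owned by the parent as a Record<string, boolean> (the props here are invented):

<SelectableTags
    title="Tags"
    className="tags-container"
    tags={{ ai: true, rust: false }}
    onSelect={(tag, selected) => console.log(tag, selected)}
    onBatchUpdate={(updates) => console.log(updates)}
    searchable
/>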
frontend/src/components/search/index.js (new file, 3 lines)
@@ -0,0 +1,3 @@
import Search from './Search'

export default Search
@@ -1,10 +1,8 @@
-import React, { useState, useEffect } from 'react'
-import { useNavigate } from 'react-router-dom'
+import { useState, useEffect } from 'react'
 import ReactMarkdown from 'react-markdown'
-import { useMCP } from '../hooks/useMCP'
-import Loading from './Loading'
+import { useMCP } from '@/hooks/useMCP'

-type SearchItem = {
+export type SearchItem = {
   filename: string
   content: string
   chunks: any[]

@@ -13,7 +11,7 @@ type SearchItem = {
   metadata: any
 }

-const Tag = ({ tags }: { tags: string[] }) => {
+export const Tag = ({ tags }: { tags: string[] }) => {
   return (
     <div className="tags">
       {tags?.map((tag: string, index: number) => (

@@ -23,11 +21,12 @@ const Tag = ({ tags }: { tags: string[] }) => {
   )
 }

-const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
+export const TextResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
   return (
     <div className="search-result-card">
       <h4>{filename || 'Untitled'}</h4>
       <Tag tags={tags} />
+      <Metadata metadata={metadata} />
       <p className="result-content">{content || 'No content available'}</p>
       {chunks && chunks.length > 0 && (
         <details className="result-chunks">

@@ -44,7 +43,7 @@ const TextResult = ({ filename, content, chunks, tags }: SearchItem) => {
   )
 }

-const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
+export const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
   return (
     <div className="search-result-card">
       <h4>{filename || 'Untitled'}</h4>

@@ -70,7 +69,7 @@ const MarkdownResult = ({ filename, content, chunks, tags, metadata }: SearchItem) => {
   )
 }

-const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
+export const ImageResult = ({ filename, tags, metadata }: SearchItem) => {
   const title = metadata?.title || filename || 'Untitled'
   const { fetchFile } = useMCP()
   const [mime_type, setMimeType] = useState<string>()

@@ -95,7 +94,7 @@ const ImageResult = ({ filename, chunks, tags, metadata }: SearchItem) => {
   )
 }

-const Metadata = ({ metadata }: { metadata: any }) => {
+export const Metadata = ({ metadata }: { metadata: any }) => {
   if (!metadata) return null
   return (
     <div className="metadata">

@@ -108,7 +107,7 @@ const Metadata = ({ metadata }: { metadata: any }) => {
   )
 }

-const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
+export const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
   return (
     <div className="search-result-card">
       <h4>{filename || 'Untitled'}</h4>

@@ -125,7 +124,7 @@ const PDFResult = ({ filename, content, tags, metadata }: SearchItem) => {
   )
 }

-const EmailResult = ({ content, tags, metadata }: SearchItem) => {
+export const EmailResult = ({ content, tags, metadata }: SearchItem) => {
   return (
     <div className="search-result-card">
       <h4>{metadata?.title || metadata?.subject || 'Untitled'}</h4>

@@ -138,7 +137,7 @@ const EmailResult = ({ content, tags, metadata }: SearchItem) => {
   )
 }

-const SearchResult = ({ result }: { result: SearchItem }) => {
+export const SearchResult = ({ result }: { result: SearchItem }) => {
   if (result.mime_type.startsWith('image/')) {
     return <ImageResult {...result} />
   }

@@ -158,86 +157,4 @@ const SearchResult = ({ result }: { result: SearchItem }) => {
   return null
 }

-const SearchResults = ({ results, isLoading }: { results: any[], isLoading: boolean }) => {
-  if (isLoading) {
-    return <Loading message="Searching..." />
-  }
-  return (
-    <div className="search-results">
-      {results.length > 0 && (
-        <div className="results-count">
-          Found {results.length} result{results.length !== 1 ? 's' : ''}
-        </div>
-      )}
-
-      {results.map((result, index) => <SearchResult key={index} result={result} />)}
-
-      {results.length === 0 && (
-        <div className="no-results">
-          No results found
-        </div>
-      )}
-    </div>
-  )
-}
-
-const SearchForm = ({ isLoading, onSearch }: { isLoading: boolean, onSearch: (query: string) => void }) => {
-  const [query, setQuery] = useState('')
-  return (
-    <form onSubmit={(e) => {
-      e.preventDefault()
-      onSearch(query)
-    }} className="search-form">
-      <div className="search-input-group">
-        <input
-          type="text"
-          value={query}
-          onChange={(e) => setQuery(e.target.value)}
-          placeholder="Search your knowledge base..."
-          className="search-input"
-        />
-        <button type="submit" disabled={isLoading} className="search-btn">
-          {isLoading ? 'Searching...' : 'Search'}
-        </button>
-      </div>
-    </form>
-  )
-}
-
-const Search = () => {
-  const navigate = useNavigate()
-  const [results, setResults] = useState([])
-  const [isLoading, setIsLoading] = useState(false)
-  const { searchKnowledgeBase } = useMCP()
-
-  const handleSearch = async (query: string) => {
-    if (!query.trim()) return
-
-    setIsLoading(true)
-    try {
-      const searchResults = await searchKnowledgeBase(query)
-      setResults(searchResults || [])
-    } catch (error) {
-      console.error('Search error:', error)
-      setResults([])
-    } finally {
-      setIsLoading(false)
-    }
-  }
-
-  return (
-    <div className="search-view">
-      <header className="search-header">
-        <button onClick={() => navigate('/ui/dashboard')} className="back-btn">
-          ← Back to Dashboard
-        </button>
-        <h2>🔍 Search Knowledge Base</h2>
-      </header>
-
-      <SearchForm isLoading={isLoading} onSearch={handleSearch} />
-      <SearchResults results={results} isLoading={isLoading} />
-    </div>
-  )
-}
-
-export default Search
+export default SearchResult
@@ -4,20 +4,20 @@ const SERVER_URL = import.meta.env.VITE_SERVER_URL || 'http://localhost:8000'
 const SESSION_COOKIE_NAME = import.meta.env.VITE_SESSION_COOKIE_NAME || 'session_id'

 // Cookie utilities
-const getCookie = (name) => {
+const getCookie = (name: string) => {
   const value = `; ${document.cookie}`
   const parts = value.split(`; ${name}=`)
   if (parts.length === 2) return parts.pop().split(';').shift()
   return null
 }

-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
   const expires = new Date()
   expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
   document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
 }

-const deleteCookie = (name) => {
+const deleteCookie = (name: string) => {
   document.cookie = `${name}=;expires=Thu, 01 Jan 1970 00:00:01 GMT;path=/`
 }

@@ -68,6 +68,7 @@ export const useAuth = () => {
     deleteCookie('access_token')
     deleteCookie('refresh_token')
     deleteCookie(SESSION_COOKIE_NAME)
+    localStorage.removeItem('oauth_client_id')
     setIsAuthenticated(false)
   }, [])

@@ -110,7 +111,7 @@ export const useAuth = () => {
   }, [logout])

   // Make authenticated API calls with automatic token refresh
-  const apiCall = useCallback(async (endpoint, options = {}) => {
+  const apiCall = useCallback(async (endpoint: string, options: RequestInit = {}) => {
     let accessToken = getCookie('access_token')

     if (!accessToken) {

@@ -122,7 +123,7 @@ export const useAuth = () => {
       'Content-Type': 'application/json',
     }

-    const requestOptions = {
+    const requestOptions: RequestInit & { headers: Record<string, string> } = {
       ...options,
       headers: { ...defaultHeaders, ...options.headers },
     }
@@ -1,5 +1,5 @@
 import { useEffect, useCallback } from 'react'
-import { useAuth } from './useAuth'
+import { useAuth } from '@/hooks/useAuth'

 const parseServerSentEvents = async (response: Response): Promise<any> => {
   const reader = response.body?.getReader()

@@ -91,10 +91,10 @@ const parseJsonRpcResponse = async (response: Response): Promise<any> => {
 }

 export const useMCP = () => {
-  const { apiCall, isAuthenticated, isLoading, checkAuth } = useAuth()
+  const { apiCall, checkAuth } = useAuth()

-  const mcpCall = useCallback(async (path: string, method: string, params: any = {}) => {
-    const response = await apiCall(`/mcp${path}`, {
+  const mcpCall = useCallback(async (method: string, params: any = {}) => {
+    const response = await apiCall(`/mcp/${method}`, {
       method: 'POST',
       headers: {
         'Accept': 'application/json, text/event-stream',

@@ -118,22 +118,46 @@ export const useMCP = () => {
     if (resp?.result?.isError) {
       throw new Error(resp?.result?.content[0].text)
     }
-    return resp?.result?.content.map((item: any) => JSON.parse(item.text))
+    return resp?.result?.content.map((item: any) => {
+      try {
+        return JSON.parse(item.text)
+      } catch (e) {
+        return item.text
+      }
+    })
   }, [apiCall])

   const listNotes = useCallback(async (path: string = "/") => {
-    return await mcpCall('/note_files', 'note_files', { path })
+    return await mcpCall('note_files', { path })
   }, [mcpCall])

   const fetchFile = useCallback(async (filename: string) => {
-    return await mcpCall('/fetch_file', 'fetch_file', { filename })
+    return await mcpCall('fetch_file', { filename })
   }, [mcpCall])

-  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10) => {
-    return await mcpCall('/search_knowledge_base', 'search_knowledge_base', {
+  const getTags = useCallback(async () => {
+    return await mcpCall('get_all_tags')
+  }, [mcpCall])
+
+  const getSubjects = useCallback(async () => {
+    return await mcpCall('get_all_subjects')
+  }, [mcpCall])
+
+  const getObservationTypes = useCallback(async () => {
+    return await mcpCall('get_all_observation_types')
+  }, [mcpCall])
+
+  const getMetadataSchemas = useCallback(async () => {
+    return (await mcpCall('get_metadata_schemas'))[0]
+  }, [mcpCall])
+
+  const searchKnowledgeBase = useCallback(async (query: string, previews: boolean = true, limit: number = 10, filters: Record<string, any> = {}, modalities: string[] = []) => {
+    return await mcpCall('search_knowledge_base', {
       query,
+      filters,
+      modalities,
       previews,
-      limit
+      limit,
     })
   }, [mcpCall])

@@ -146,5 +170,9 @@ export const useMCP = () => {
     fetchFile,
     listNotes,
     searchKnowledgeBase,
+    getTags,
+    getSubjects,
+    getObservationTypes,
+    getMetadataSchemas,
   }
 }
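A sketch of a call through the reworked hook, using the two new positional arguments; the filter keys and modality names below are illustrative, not a fixed vocabulary:

const { searchKnowledgeBase } = useMCP()
const results = await searchKnowledgeBase(
    'quarterly report',                  // query
    true,                                // previews
    10,                                  // limit
    { tags: ['work'], min_size: 1024 },  // filters (SearchForm strips empty values first)
    ['mail'],                            // modalities
)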
@@ -14,7 +14,7 @@ const generateCodeVerifier = () => {
     .replace(/=/g, '')
 }

-const generateCodeChallenge = async (verifier) => {
+const generateCodeChallenge = async (verifier: string) => {
   const data = new TextEncoder().encode(verifier)
   const digest = await crypto.subtle.digest('SHA-256', data)
   return btoa(String.fromCharCode(...new Uint8Array(digest)))

@@ -33,7 +33,7 @@ const generateState = () => {
 }

 // Storage utilities
-const setCookie = (name, value, days = 30) => {
+const setCookie = (name: string, value: string, days = 30) => {
   const expires = new Date()
   expires.setTime(expires.getTime() + days * 24 * 60 * 60 * 1000)
   document.cookie = `${name}=${value};expires=${expires.toUTCString()};path=/;SameSite=Lax`
@@ -1,6 +1,6 @@
 import { StrictMode } from 'react'
 import { createRoot } from 'react-dom/client'
-import App from './App.jsx'
+import App from '@/App.jsx'

 createRoot(document.getElementById('root')).render(
   <StrictMode>
frontend/src/types/mcp.tsx (new file, 9 lines)
@@ -0,0 +1,9 @@
export type CollectionMetadata = {
    schema: Record<string, SchemaArg>
    size: number
}

export type SchemaArg = {
    type: string
    description: string
}
frontend/tsconfig.json (new file, 40 lines)
@@ -0,0 +1,40 @@
{
  "compilerOptions": {
    "target": "ES2020",
    "useDefineForClassFields": true,
    "lib": [
      "ES2020",
      "DOM",
      "DOM.Iterable"
    ],
    "module": "ESNext",
    "skipLibCheck": true,
    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "noEmit": true,
    "jsx": "react-jsx",
    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "noFallthroughCasesInSwitch": true,
    /* Absolute imports */
    "baseUrl": ".",
    "paths": {
      "@/*": [
        "src/*"
      ]
    }
  },
  "include": [
    "src"
  ],
  "references": [
    {
      "path": "./tsconfig.node.json"
    }
  ]
}
frontend/tsconfig.node.json (new file, 12 lines)
@@ -0,0 +1,12 @@
{
  "compilerOptions": {
    "composite": true,
    "skipLibCheck": true,
    "module": "ESNext",
    "moduleResolution": "bundler",
    "allowSyntheticDefaultImports": true
  },
  "include": [
    "vite.config.js"
  ]
}
@@ -1,8 +1,14 @@
 import { defineConfig } from 'vite'
 import react from '@vitejs/plugin-react'
+import path from 'path'

 // https://vite.dev/config/
 export default defineConfig({
   plugins: [react()],
   base: '/ui/',
+  resolve: {
+    alias: {
+      '@': path.resolve(__dirname, './src'),
+    },
+  },
 })
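The "paths" mapping in tsconfig.json only informs the TypeScript checker; it does not change how Vite resolves modules at build time, which is what the resolve.alias entry above does, so the two must stay in sync. With both in place, the relative imports swapped out across this changeset become root-relative:

// before: brittle relative path
import { useAuth } from '../hooks/useAuth'
// after: resolves to src/hooks/useAuth via the '@' alias
import { useAuth } from '@/hooks/useAuth'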
@@ -1,2 +1,3 @@
 import memory.api.MCP.manifest
 import memory.api.MCP.memory
+import memory.api.MCP.metadata
@@ -127,7 +127,6 @@ async def handle_login(request: Request):
         return login_form(request, oauth_params, "Invalid email or password")

     redirect_url = await oauth_provider.complete_authorization(oauth_params, user)
-    print("redirect_url", redirect_url)
     if redirect_url.startswith("http://anysphere.cursor-retrieval"):
         redirect_url = redirect_url.replace("http://", "cursor://")
     return RedirectResponse(url=redirect_url, status_code=302)
@@ -3,23 +3,25 @@ MCP tools for the epistemic sparring partner system.
 """

 import logging
+import mimetypes
 import pathlib
 from datetime import datetime, timezone
-import mimetypes
 from typing import Any

 from pydantic import BaseModel
-from sqlalchemy import Text, func
+from sqlalchemy import Text
 from sqlalchemy import cast as sql_cast
 from sqlalchemy.dialects.postgresql import ARRAY

+from memory.api.MCP.tools import mcp
 from memory.api.search.search import SearchFilters, search
 from memory.common import extract, settings
+from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION
 from memory.common.celery_app import app as celery_app
 from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS
 from memory.common.db.connection import make_session
-from memory.common.db.models import SourceItem, AgentObservation
+from memory.common.db.models import AgentObservation, SourceItem
 from memory.common.formatters import observation
-from memory.common.celery_app import app as celery_app, SYNC_OBSERVATION, SYNC_NOTE
-from memory.api.MCP.tools import mcp

 logger = logging.getLogger(__name__)
@@ -47,11 +49,13 @@ def filter_observation_source_ids(
     return source_ids


-def filter_source_ids(
-    modalities: set[str],
-    tags: list[str] | None = None,
-):
-    if not tags:
+def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] | None:
+    if source_ids := filters.get("source_ids"):
+        return source_ids
+
+    tags = filters.get("tags")
+    size = filters.get("size")
+    if not (tags or size):
         return None

     with make_session() as session:

@@ -62,6 +66,8 @@ def filter_source_ids(
             items_query = items_query.filter(
                 SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))),
             )
+        if size:
+            items_query = items_query.filter(SourceItem.size == size)
         if modalities:
             items_query = items_query.filter(SourceItem.modality.in_(modalities))
         source_ids = [item.id for item in items_query.all()]

@@ -69,51 +75,12 @@ def filter_source_ids(
     return source_ids


-@mcp.tool()
-async def get_all_tags() -> list[str]:
-    """
-    Get all unique tags used across the entire knowledge base.
-    Returns sorted list of tags from both observations and content.
-    """
-    with make_session() as session:
-        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
-        return sorted({row[0] for row in tags_query if row[0] is not None})
-
-
-@mcp.tool()
-async def get_all_subjects() -> list[str]:
-    """
-    Get all unique subjects from observations about the user.
-    Returns sorted list of subject identifiers used in observations.
-    """
-    with make_session() as session:
-        return sorted(
-            r.subject for r in session.query(AgentObservation.subject).distinct()
-        )
-
-
-@mcp.tool()
-async def get_all_observation_types() -> list[str]:
-    """
-    Get all observation types that have been used.
-    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
-    """
-    with make_session() as session:
-        return sorted(
-            {
-                r.observation_type
-                for r in session.query(AgentObservation.observation_type).distinct()
-                if r.observation_type is not None
-            }
-        )
-
-
 @mcp.tool()
 async def search_knowledge_base(
     query: str,
-    previews: bool = False,
+    filters: dict[str, Any],
     modalities: set[str] = set(),
-    tags: list[str] = [],
+    previews: bool = False,
     limit: int = 10,
 ) -> list[dict]:
     """

@@ -125,7 +92,7 @@ async def search_knowledge_base(
         query: Natural language search query - be descriptive about what you're looking for
         previews: Include actual content in results - when false only a snippet is returned
         modalities: Filter by type: email, blog, book, forum, photo, comic, webpage (empty = all)
-        tags: Filter by tags - content must have at least one matching tag
+        filters: Filter by tags, source_ids, etc.
         limit: Max results (1-100)

     Returns: List of search results with id, score, chunks, content, filename

@@ -137,6 +104,9 @@ async def search_knowledge_base(
         modalities = set(ALL_COLLECTIONS.keys())
     modalities = set(modalities) & ALL_COLLECTIONS.keys() - OBSERVATION_COLLECTIONS

+    search_filters = SearchFilters(**filters)
+    search_filters["source_ids"] = filter_source_ids(modalities, search_filters)
+
     upload_data = extract.extract_text(query)
     results = await search(
         upload_data,

@@ -145,10 +115,7 @@ async def search_knowledge_base(
         limit=limit,
         min_text_score=0.4,
         min_multimodal_score=0.25,
-        filters=SearchFilters(
-            tags=tags,
-            source_ids=filter_source_ids(tags=tags, modalities=modalities),
-        ),
+        filters=search_filters,
     )

     return [result.model_dump() for result in results]
src/memory/api/MCP/metadata.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import logging
from collections import defaultdict
from typing import Annotated, TypedDict, get_args, get_type_hints

from memory.common import qdrant
from sqlalchemy import func

from memory.api.MCP.tools import mcp
from memory.common.db.connection import make_session
from memory.common.db.models import SourceItem
from memory.common.db.models.source_items import AgentObservation

logger = logging.getLogger(__name__)


class SchemaArg(TypedDict):
    type: str | None
    description: str | None


class CollectionMetadata(TypedDict):
    schema: dict[str, SchemaArg]
    size: int


def from_annotation(annotation: Annotated) -> SchemaArg | None:
    try:
        type_, description = get_args(annotation)
        type_str = str(type_)
        if type_str.startswith("typing."):
            type_str = type_str[7:]
        elif len((parts := type_str.split("'"))) > 1:
            type_str = parts[1]
        return SchemaArg(type=type_str, description=description)
    except IndexError:
        logger.error(f"Error from annotation: {annotation}")
        return None


def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]:
    if not hasattr(klass, "as_payload"):
        return {}

    if not (payload_type := get_type_hints(klass.as_payload).get("return")):
        return {}

    return {
        name: schema
        for name, arg in payload_type.__annotations__.items()
        if (schema := from_annotation(arg))
    }


@mcp.tool()
async def get_metadata_schemas() -> dict[str, CollectionMetadata]:
    """Get the metadata schema for each collection used in the knowledge base.

    These schemas can be used to filter the knowledge base.

    Returns: A mapping of collection names to their metadata schemas with field types and descriptions.

    Example:
    ```
    {
        "mail": {"subject": {"type": "str", "description": "The subject of the email."}},
        "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}}
    }
    ```
    """
    client = qdrant.get_qdrant_client()
    sizes = qdrant.get_collection_sizes(client)
    schemas = defaultdict(dict)
    for klass in SourceItem.__subclasses__():
        for collection in klass.get_collections():
            schemas[collection].update(get_schema(klass))

    return {
        collection: CollectionMetadata(schema=schema, size=size)
        for collection, schema in schemas.items()
        if (size := sizes.get(collection))
    }


@mcp.tool()
async def get_all_tags() -> list[str]:
    """Get all unique tags used across the entire knowledge base.

    Returns sorted list of tags from both observations and content.
    """
    with make_session() as session:
        tags_query = session.query(func.unnest(SourceItem.tags)).distinct()
        return sorted({row[0] for row in tags_query if row[0] is not None})


@mcp.tool()
async def get_all_subjects() -> list[str]:
    """Get all unique subjects from observations about the user.

    Returns sorted list of subject identifiers used in observations.
    """
    with make_session() as session:
        return sorted(
            r.subject for r in session.query(AgentObservation.subject).distinct()
        )


@mcp.tool()
async def get_all_observation_types() -> list[str]:
    """Get all observation types that have been used.

    Standard types are belief, preference, behavior, contradiction, general, but there can be more.
    """
    with make_session() as session:
        return sorted(
            {
                r.observation_type
                for r in session.query(AgentObservation.observation_type).distinct()
                if r.observation_type is not None
            }
        )
@@ -71,6 +71,13 @@ async def input_type(item: str | UploadFile) -> list[extract.DataChunk]:
 # SQLAdmin setup with OAuth protection
 engine = get_engine()
 admin = Admin(app, engine)
+admin.app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # [settings.SERVER_URL]
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

 # Setup admin with OAuth protection using existing OAuth provider
 setup_admin(admin)
@ -1,7 +1,8 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
import asyncio
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import qdrant_client
|
||||
from PIL import Image
|
||||
@ -54,7 +55,7 @@ def annotated_chunk(
|
||||
)
|
||||
|
||||
|
||||
def query_chunks(
|
||||
async def query_chunks(
|
||||
client: qdrant_client.QdrantClient,
|
||||
upload_data: list[extract.DataChunk],
|
||||
allowed_modalities: set[str],
|
||||
@ -73,22 +74,104 @@ def query_chunks(
|
||||
|
||||
vectors = embedder(chunks, input_type="query")
|
||||
|
||||
return {
|
||||
collection: [
|
||||
r
|
||||
for vector in vectors
|
||||
for r in qdrant.search_vectors(
|
||||
# Create all search tasks to run in parallel
|
||||
search_tasks = []
|
||||
task_metadata = [] # Keep track of which collection and vector each task corresponds to
|
||||
|
||||
for collection in allowed_modalities:
|
||||
for vector in vectors:
|
||||
task = asyncio.to_thread(
|
||||
qdrant.search_vectors,
|
||||
client=client,
|
||||
collection_name=collection,
|
||||
query_vector=vector,
|
||||
limit=limit,
|
||||
filter_params=filters,
|
||||
)
|
||||
if r.score >= min_score
|
||||
]
|
||||
for collection in allowed_modalities
|
||||
search_tasks.append(task)
|
||||
task_metadata.append((collection, vector))
|
||||
|
||||
# Run all searches in parallel
|
||||
if not search_tasks:
|
||||
return {}
|
||||
|
||||
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
# Group results by collection
|
||||
results_by_collection: dict[str, list[qdrant_models.ScoredPoint]] = {
|
||||
collection: [] for collection in allowed_modalities
|
||||
}
|
||||
|
||||
for (collection, _), result in zip(task_metadata, search_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Search failed for collection {collection}: {result}")
|
||||
continue
|
||||
|
||||
# Filter by min_score and add to collection results
|
||||
result_list = cast(list[qdrant_models.ScoredPoint], result)
|
||||
filtered_results = [r for r in result_list if r.score >= min_score]
|
||||
results_by_collection[collection].extend(filtered_results)
|
||||
|
||||
return results_by_collection
|
||||
|
||||
|
||||
+def merge_range_filter(
+    filters: list[dict[str, Any]], key: str, val: Any
+) -> list[dict[str, Any]]:
+    direction, field = key.split("_", maxsplit=1)
+    item = next((f for f in filters if f["key"] == field), None)
+    if not item:
+        item = {"key": field, "range": {}}
+        filters.append(item)
+
+    if direction == "min":
+        item["range"]["gte"] = val
+    elif direction == "max":
+        item["range"]["lte"] = val
+    return filters
+
+
+def merge_filters(
+    filters: list[dict[str, Any]], key: str, val: Any
+) -> list[dict[str, Any]]:
+    if not val and val != 0:
+        return filters
+
+    list_filters = ["tags", "recipients", "observation_types", "authors"]
+    range_filters = [
+        "min_sent_at",
+        "max_sent_at",
+        "min_published",
+        "max_published",
+        "min_size",
+        "max_size",
+        "min_created_at",
+        "max_created_at",
+    ]
+    if key in list_filters:
+        filters.append({"key": key, "match": {"any": val}})
+
+    elif key in range_filters:
+        return merge_range_filter(filters, key, val)
+
+    elif key == "min_confidences":
+        confidence_filters = [
+            {
+                "key": f"confidence.{confidence_type}",
+                "range": {"gte": min_confidence_score},
+            }
+            for confidence_type, min_confidence_score in cast(dict, val).items()
+        ]
+        filters.extend(confidence_filters)
+
+    elif key == "source_ids":
+        filters.append({"key": "id", "match": {"any": val}})
+
+    else:
+        filters.append({"key": key, "match": val})
+
+    return filters


async def search_embeddings(
    data: list[extract.DataChunk],
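Assuming the two helpers above, a quick sketch of the Qdrant filter shapes they produce (the values are made up):

filters: list[dict] = []
filters = merge_filters(filters, "tags", ["work", "ai"])
filters = merge_filters(filters, "min_published", "2024-01-01")
filters = merge_filters(filters, "max_published", "2024-12-31")
filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})

# Expected contents of `filters` after the calls above:
# [
#     {"key": "tags", "match": {"any": ["work", "ai"]}},
#     {"key": "published", "range": {"gte": "2024-01-01", "lte": "2024-12-31"}},
#     {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
# ]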
@@ -96,7 +179,7 @@ async def search_embeddings(
    modalities: set[str] = set(),
    limit: int = 10,
    min_score: float = 0.3,
-    filters: SearchFilters = SearchFilters(),
+    filters: SearchFilters = {},
    multimodal: bool = False,
) -> list[tuple[SourceData, AnnotatedChunk]]:
    """

@@ -111,36 +194,19 @@ async def search_embeddings(
    - filters: Filters to apply to the search results
    - multimodal: Whether to search in multimodal collections
    """
-    query_filters = {"must": []}
-
-    # Handle structured confidence filtering
-    if min_confidences := filters.get("min_confidences"):
-        confidence_filters = [
-            {
-                "key": f"confidence.{confidence_type}",
-                "range": {"gte": min_confidence_score},
-            }
-            for confidence_type, min_confidence_score in min_confidences.items()
-        ]
-        query_filters["must"].extend(confidence_filters)
-
-    if tags := filters.get("tags"):
-        query_filters["must"].append({"key": "tags", "match": {"any": tags}})
-
-    if observation_types := filters.get("observation_types"):
-        query_filters["must"].append(
-            {"key": "observation_type", "match": {"any": observation_types}}
-        )
+    search_filters = []
+    for key, val in filters.items():
+        search_filters = merge_filters(search_filters, key, val)

    client = qdrant.get_qdrant_client()
-    results = query_chunks(
+    results = await query_chunks(
        client,
        data,
        modalities,
        embedding.embed_text if not multimodal else embedding.embed_mixed,
        min_score=min_score,
        limit=limit,
-        filters=query_filters if query_filters["must"] else None,
+        filters={"must": search_filters} if search_filters else None,
    )
    search_results = {k: results.get(k, []) for k in modalities}
@@ -17,6 +17,7 @@ from memory.common.collections import (
    MULTIMODAL_COLLECTIONS,
    TEXT_COLLECTIONS,
)
+from memory.common import settings

logger = logging.getLogger(__name__)

@@ -45,50 +46,55 @@ async def search(
    """
    allowed_modalities = modalities & ALL_COLLECTIONS.keys()

-    text_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & TEXT_COLLECTIONS,
-            limit,
-            min_text_score,
-            filters,
-            multimodal=False,
-        ),
-        timeout,
-    )
-    multimodal_embeddings_results = with_timeout(
-        search_embeddings(
-            data,
-            previews,
-            allowed_modalities & MULTIMODAL_COLLECTIONS,
-            limit,
-            min_multimodal_score,
-            filters,
-            multimodal=True,
-        ),
-        timeout,
-    )
-    bm25_results = with_timeout(
-        search_bm25(
-            " ".join([c for chunk in data for c in chunk.data if isinstance(c, str)]),
-            modalities,
-            limit=limit,
-            filters=filters,
-        ),
-        timeout,
-    )
+    searches = []
+    if settings.ENABLE_EMBEDDING_SEARCH:
+        searches = [
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & TEXT_COLLECTIONS,
+                    limit,
+                    min_text_score,
+                    filters,
+                    multimodal=False,
+                ),
+                timeout,
+            ),
+            with_timeout(
+                search_embeddings(
+                    data,
+                    previews,
+                    allowed_modalities & MULTIMODAL_COLLECTIONS,
+                    limit,
+                    min_multimodal_score,
+                    filters,
+                    multimodal=True,
+                ),
+                timeout,
+            ),
+        ]
+    if settings.ENABLE_BM25_SEARCH:
+        searches.append(
+            with_timeout(
+                search_bm25(
+                    " ".join(
+                        [c for chunk in data for c in chunk.data if isinstance(c, str)]
+                    ),
+                    modalities,
+                    limit=limit,
+                    filters=filters,
+                ),
+                timeout,
+            )
+        )

-    results = await asyncio.gather(
-        text_embeddings_results,
-        multimodal_embeddings_results,
-        bm25_results,
-        return_exceptions=False,
-    )
-    text_results, multi_results, bm25_results = results
-    all_results = text_results + multi_results
-    if len(all_results) < limit:
-        all_results += bm25_results
+    search_results = await asyncio.gather(*searches, return_exceptions=False)
+    all_results = []
+    for results in search_results:
+        if len(all_results) >= limit:
+            break
+        all_results.extend(results)

    results = group_chunks(all_results, previews or False)
    return sorted(results, key=lambda x: max(c.score for c in x.chunks), reverse=True)
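`with_timeout` is an existing helper in this repo whose body is not part of this diff; for reading the code above, a plausible minimal stand-in (an assumption, not the actual implementation) would wrap each search in `asyncio.wait_for` and fall back to an empty result list:

import asyncio
from typing import Any


async def with_timeout(coro, timeout: float) -> list[Any]:
    # Assumed behaviour: return [] instead of raising when a search times out,
    # so asyncio.gather(..., return_exceptions=False) stays safe.
    try:
        return await asyncio.wait_for(coro, timeout)
    except asyncio.TimeoutError:
        return []


async def slow_search() -> list[str]:
    await asyncio.sleep(2)
    return ["hit"]


async def main() -> None:
    searches = [with_timeout(slow_search(), 0.1)]
    print(await asyncio.gather(*searches))  # [[]] -- the slow search timed out


asyncio.run(main())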
@@ -65,9 +65,9 @@ class SearchResult(BaseModel):


class SearchFilters(TypedDict):
    subject: NotRequired[str | None]
    min_size: NotRequired[int]
    max_size: NotRequired[int]
    min_confidences: NotRequired[dict[str, float]]
    tags: NotRequired[list[str] | None]
    observation_types: NotRequired[list[str] | None]
    source_ids: NotRequired[list[int] | None]
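Since `SearchFilters` is a `TypedDict` whose fields are all `NotRequired`, callers pass a plain dict with any subset of keys, e.g.:

filters: SearchFilters = {
    "tags": ["ai", "memory"],
    "min_size": 1024,
    "min_confidences": {"observation_accuracy": 0.7},
    "source_ids": [1, 2, 3],
}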
@@ -115,7 +115,6 @@ def group_chunks(
        if isinstance(contents, dict):
            tags = contents.pop("tags", [])
            content = contents.pop("content", None)
-            print(content)
        else:
            content = contents
            contents = {}
@@ -4,6 +4,7 @@ from memory.common.db.models.source_item import (
    SourceItem,
    ConfidenceScore,
    clean_filename,
+    SourceItemPayload,
)
from memory.common.db.models.source_items import (
    MailMessage,

@@ -19,6 +20,14 @@ from memory.common.db.models.source_items import (
    Photo,
    MiscDoc,
    Note,
+    MailMessagePayload,
+    EmailAttachmentPayload,
+    AgentObservationPayload,
+    BlogPostPayload,
+    ComicPayload,
+    BookSectionPayload,
+    NotePayload,
+    ForumPostPayload,
)
from memory.common.db.models.observations import (
    ObservationContradiction,

@@ -40,6 +49,18 @@ from memory.common.db.models.users import (
    OAuthRefreshToken,
)

+Payload = (
+    SourceItemPayload
+    | AgentObservationPayload
+    | NotePayload
+    | BlogPostPayload
+    | ComicPayload
+    | BookSectionPayload
+    | ForumPostPayload
+    | EmailAttachmentPayload
+    | MailMessagePayload
+)
+
__all__ = [
    "Base",
    "Chunk",

@@ -75,4 +96,6 @@ __all__ = [
    "OAuthClientInformation",
    "OAuthState",
    "OAuthRefreshToken",
+    # Payloads
+    "Payload",
]
@@ -4,7 +4,7 @@ Database models for the knowledge base system.

import pathlib
import re
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, TypedDict, cast
import uuid

from PIL import Image

@@ -36,6 +36,17 @@ import memory.common.summarizer as summarizer
from memory.common.db.models.base import Base


+class MetadataSchema(TypedDict):
+    type: str
+    description: str
+
+
+class SourceItemPayload(TypedDict):
+    source_id: Annotated[int, "Unique identifier of the source item"]
+    tags: Annotated[list[str], "List of tags for categorization"]
+    size: Annotated[int | None, "Size of the content in bytes"]
+
+
@event.listens_for(Session, "before_flush")
def handle_duplicate_sha256(session, flush_context, instances):
    """
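The payload classes keep their field documentation in `Annotated` metadata, so it stays machine-readable. One way to recover it at runtime (a sketch, not something this diff does; the class body is repeated here to keep the snippet self-contained):

from typing import Annotated, TypedDict, get_type_hints


class SourceItemPayload(TypedDict):
    source_id: Annotated[int, "Unique identifier of the source item"]
    tags: Annotated[list[str], "List of tags for categorization"]
    size: Annotated[int | None, "Size of the content in bytes"]


hints = get_type_hints(SourceItemPayload, include_extras=True)
for field, hint in hints.items():
    # __metadata__ holds the Annotated extras -- here, the description string.
    print(field, "->", hint.__metadata__[0])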
@@ -344,12 +355,17 @@ class SourceItem(Base):
    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
        return [self._make_chunk(data, metadata) for data in self._chunk_contents()]

-    def as_payload(self) -> dict:
-        return {
-            "source_id": self.id,
-            "tags": self.tags,
-            "size": self.size,
-        }
+    def as_payload(self) -> SourceItemPayload:
+        return SourceItemPayload(
+            source_id=cast(int, self.id),
+            tags=cast(list[str], self.tags),
+            size=cast(int | None, self.size),
+        )
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """Return the list of Qdrant collections this SourceItem type can be stored in."""
+        return [cls.__tablename__]

    @property
    def display_contents(self) -> str | dict | None:
@@ -5,7 +5,7 @@ Database models for the knowledge base system.
import pathlib
import textwrap
from datetime import datetime
-from typing import Any, Sequence, cast
+from typing import Any, Annotated, Sequence, cast

from PIL import Image
from sqlalchemy import (

@@ -32,11 +32,21 @@ import memory.common.formatters.observation as observation
from memory.common.db.models.source_item import (
    SourceItem,
    Chunk,
+    SourceItemPayload,
    clean_filename,
    chunk_mixed,
)


+class MailMessagePayload(SourceItemPayload):
+    message_id: Annotated[str, "Unique email message identifier"]
+    subject: Annotated[str, "Email subject line"]
+    sender: Annotated[str, "Email sender address"]
+    recipients: Annotated[list[str], "List of recipient email addresses"]
+    folder: Annotated[str, "Email folder name"]
+    date: Annotated[str | None, "Email sent date in ISO format"]
+
+
class MailMessage(SourceItem):
    __tablename__ = "mail_message"

@@ -80,17 +90,21 @@ class MailMessage(SourceItem):
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

-    def as_payload(self) -> dict:
-        return {
-            **super().as_payload(),
-            "message_id": self.message_id,
-            "subject": self.subject,
-            "sender": self.sender,
-            "recipients": self.recipients,
-            "folder": self.folder,
-            "tags": self.tags + [self.sender] + self.recipients,
-            "date": (self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
-        }
+    def as_payload(self) -> MailMessagePayload:
+        base_payload = super().as_payload() | {
+            "tags": cast(list[str], self.tags)
+            + [cast(str, self.sender)]
+            + cast(list[str], self.recipients)
+        }
+        return MailMessagePayload(
+            **cast(dict, base_payload),
+            message_id=cast(str, self.message_id),
+            subject=cast(str, self.subject),
+            sender=cast(str, self.sender),
+            recipients=cast(list[str], self.recipients),
+            folder=cast(str, self.folder),
+            date=(self.sent_at and self.sent_at.isoformat() or None),  # type: ignore
+        )

    @property
    def parsed_content(self) -> dict[str, Any]:
@@ -152,7 +166,7 @@ class MailMessage(SourceItem):

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        content = self.parsed_content
-        chunks = extract.extract_text(cast(str, self.body))
+        chunks = extract.extract_text(cast(str, self.body), modality="mail")

        def add_header(item: extract.MulitmodalChunk) -> extract.MulitmodalChunk:
            if isinstance(item, str):

@@ -163,6 +177,10 @@ class MailMessage(SourceItem):
            chunk.data = [add_header(item) for item in chunk.data]
        return chunks

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["mail"]
+
    # Add indexes
    __table_args__ = (
        Index("mail_sent_idx", "sent_at"),

@@ -171,6 +189,14 @@ class MailMessage(SourceItem):
    )


+class EmailAttachmentPayload(SourceItemPayload):
+    filename: Annotated[str, "Name of the document file"]
+    content_type: Annotated[str, "MIME type of the document"]
+    mail_message_id: Annotated[int, "Associated email message ID"]
+    sent_at: Annotated[str | None, "Document creation timestamp"]
+    created_at: Annotated[str | None, "Document creation timestamp"]
+
+
class EmailAttachment(SourceItem):
    __tablename__ = "email_attachment"

@@ -190,17 +216,21 @@
        "polymorphic_identity": "email_attachment",
    }

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> EmailAttachmentPayload:
+        return EmailAttachmentPayload(
            **super().as_payload(),
-            "filename": self.filename,
-            "content_type": self.mime_type,
-            "size": self.size,
-            "created_at": (self.created_at and self.created_at.isoformat() or None),  # type: ignore
-            "mail_message_id": self.mail_message_id,
-        }
+            created_at=(self.created_at and self.created_at.isoformat() or None),  # type: ignore
+            filename=cast(str, self.filename),
+            content_type=cast(str, self.mime_type),
+            mail_message_id=cast(int, self.mail_message_id),
+            sent_at=(
+                self.mail_message.sent_at
+                and self.mail_message.sent_at.isoformat()
+                or None
+            ),  # type: ignore
+        )

-    def data_chunks(self, metadata: dict[str, Any] = {}) -> Sequence[Chunk]:
+    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        if cast(str | None, self.filename):
            contents = (
                settings.FILE_STORAGE_DIR / cast(str, self.filename)

@@ -208,8 +238,7 @@ class EmailAttachment(SourceItem):
        else:
            contents = cast(str, self.content)

-        chunks = extract.extract_data_chunks(cast(str, self.mime_type), contents)
-        return [self._make_chunk(c, metadata) for c in chunks]
+        return extract.extract_data_chunks(cast(str, self.mime_type), contents)

    @property
    def display_contents(self) -> dict:

@@ -221,6 +250,11 @@ class EmailAttachment(SourceItem):
    # Add indexes
    __table_args__ = (Index("email_attachment_message_idx", "mail_message_id"),)

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        """EmailAttachment can go to different collections based on mime_type"""
+        return ["doc", "text", "blog", "photo", "book"]
+

class ChatMessage(SourceItem):
    __tablename__ = "chat_message"
@@ -285,6 +319,16 @@ class Photo(SourceItem):
    __table_args__ = (Index("photo_taken_idx", "exif_taken_at"),)


+class ComicPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the comic"]
+    author: Annotated[str | None, "Author of the comic"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    volume: Annotated[str | None, "Volume number"]
+    issue: Annotated[str | None, "Issue number"]
+    page: Annotated[int | None, "Page number"]
+    url: Annotated[str | None, "URL of the comic"]
+
+
class Comic(SourceItem):
    __tablename__ = "comic"

@@ -305,18 +349,17 @@

    __table_args__ = (Index("comic_author_idx", "author"),)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> ComicPayload:
+        return ComicPayload(
            **super().as_payload(),
-            "title": self.title,
-            "author": self.author,
-            "published": self.published,
-            "volume": self.volume,
-            "issue": self.issue,
-            "page": self.page,
-            "url": self.url,
-        }
-        return {k: v for k, v in payload.items() if v is not None}
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(self.published and self.published.isoformat() or None),  # type: ignore
+            volume=cast(str | None, self.volume),
+            issue=cast(str | None, self.issue),
+            page=cast(int | None, self.page),
+            url=cast(str | None, self.url),
+        )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        image = Image.open(settings.FILE_STORAGE_DIR / cast(str, self.filename))
@@ -324,6 +367,17 @@ class Comic(SourceItem):
        return [extract.DataChunk(data=[image, description])]


+class BookSectionPayload(SourceItemPayload):
+    title: Annotated[str, "Title of the book"]
+    author: Annotated[str | None, "Author of the book"]
+    book_id: Annotated[int, "Unique identifier of the book"]
+    section_title: Annotated[str, "Title of the section"]
+    section_number: Annotated[int, "Number of the section"]
+    section_level: Annotated[int, "Level of the section"]
+    start_page: Annotated[int, "Starting page number"]
+    end_page: Annotated[int, "Ending page number"]
+
+
class BookSection(SourceItem):
    """Individual sections/chapters of books"""

@@ -361,19 +415,22 @@ class BookSection(SourceItem):
        Index("book_section_level_idx", "section_level", "section_number"),
    )

-    def as_payload(self) -> dict:
-        vals = {
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["book"]
+
+    def as_payload(self) -> BookSectionPayload:
+        return BookSectionPayload(
            **super().as_payload(),
-            "title": self.book.title,
-            "author": self.book.author,
-            "book_id": self.book_id,
-            "section_title": self.section_title,
-            "section_number": self.section_number,
-            "section_level": self.section_level,
-            "start_page": self.start_page,
-            "end_page": self.end_page,
-        }
-        return {k: v for k, v in vals.items() if v}
+            title=cast(str, self.book.title),
+            author=cast(str | None, self.book.author),
+            book_id=cast(int, self.book_id),
+            section_title=cast(str, self.section_title),
+            section_number=cast(int, self.section_number),
+            section_level=cast(int, self.section_level),
+            start_page=cast(int, self.start_page),
+            end_page=cast(int, self.end_page),
+        )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        content = cast(str, self.content.strip())
@@ -397,6 +454,16 @@ class BookSection(SourceItem):
        ]


+class BlogPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the blog post"]
+    title: Annotated[str, "Title of the blog post"]
+    author: Annotated[str | None, "Author of the blog post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    description: Annotated[str | None, "Description of the blog post"]
+    domain: Annotated[str | None, "Domain of the blog post"]
+    word_count: Annotated[int | None, "Word count of the blog post"]
+
+
class BlogPost(SourceItem):
    __tablename__ = "blog_post"

@@ -428,27 +495,39 @@ class BlogPost(SourceItem):
        Index("blog_post_word_count_idx", "word_count"),
    )

-    def as_payload(self) -> dict:
+    def as_payload(self) -> BlogPostPayload:
        published_date = cast(datetime | None, self.published)
        metadata = cast(dict | None, self.webpage_metadata) or {}

-        payload = {
+        return BlogPostPayload(
            **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "author": self.author,
-            "published": published_date and published_date.isoformat(),
-            "description": self.description,
-            "domain": self.domain,
-            "word_count": self.word_count,
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            author=cast(str | None, self.author),
+            published=(published_date and published_date.isoformat() or None),  # type: ignore
+            description=cast(str | None, self.description),
+            domain=cast(str | None, self.domain),
+            word_count=cast(int | None, self.word_count),
            **metadata,
-        }
-        return {k: v for k, v in payload.items() if v}
+        )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))


+class ForumPostPayload(SourceItemPayload):
+    url: Annotated[str, "URL of the forum post"]
+    title: Annotated[str, "Title of the forum post"]
+    description: Annotated[str | None, "Description of the forum post"]
+    authors: Annotated[list[str] | None, "Authors of the forum post"]
+    published: Annotated[str | None, "Publication date in ISO format"]
+    slug: Annotated[str | None, "Slug of the forum post"]
+    karma: Annotated[int | None, "Karma score of the forum post"]
+    votes: Annotated[int | None, "Number of votes on the forum post"]
+    score: Annotated[int | None, "Score of the forum post"]
+    comments: Annotated[int | None, "Number of comments on the forum post"]
+
+
class ForumPost(SourceItem):
    __tablename__ = "forum_post"

@@ -479,20 +558,20 @@ class ForumPost(SourceItem):
        Index("forum_post_title_idx", "title"),
    )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> ForumPostPayload:
+        return ForumPostPayload(
            **super().as_payload(),
-            "url": self.url,
-            "title": self.title,
-            "description": self.description,
-            "authors": self.authors,
-            "published_at": self.published_at,
-            "slug": self.slug,
-            "karma": self.karma,
-            "votes": self.votes,
-            "score": self.score,
-            "comments": self.comments,
-        }
+            url=cast(str, self.url),
+            title=cast(str, self.title),
+            description=cast(str | None, self.description),
+            authors=cast(list[str] | None, self.authors),
+            published=(self.published_at and self.published_at.isoformat() or None),  # type: ignore
+            slug=cast(str | None, self.slug),
+            karma=cast(int | None, self.karma),
+            votes=cast(int | None, self.votes),
+            score=cast(int | None, self.score),
+            comments=cast(int | None, self.comments),
+        )

    def _chunk_contents(self) -> Sequence[extract.DataChunk]:
        return chunk_mixed(cast(str, self.content), cast(list[str], self.images))

@@ -545,6 +624,12 @@ class GithubItem(SourceItem):
    )


+class NotePayload(SourceItemPayload):
+    note_type: Annotated[str | None, "Category of the note"]
+    subject: Annotated[str | None, "What the note is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the note"]
+
+
class Note(SourceItem):
    """A quick note of something of interest."""

@@ -565,13 +650,13 @@
        Index("note_subject_idx", "subject"),
    )

-    def as_payload(self) -> dict:
-        return {
+    def as_payload(self) -> NotePayload:
+        return NotePayload(
            **super().as_payload(),
-            "note_type": self.note_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-        }
+            note_type=cast(str | None, self.note_type),
+            subject=cast(str | None, self.subject),
+            confidence=self.confidence_dict,
+        )

    @property
    def display_contents(self) -> dict:

@@ -602,6 +687,19 @@ class Note(SourceItem):
            self.as_text(cast(str, self.content), cast(str | None, self.subject))
        )

+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["text"]  # Notes go to the text collection
+
+
+class AgentObservationPayload(SourceItemPayload):
+    session_id: Annotated[str | None, "Session ID for the observation"]
+    observation_type: Annotated[str, "Type of observation"]
+    subject: Annotated[str, "What/who the observation is about"]
+    confidence: Annotated[dict[str, float], "Confidence scores for the observation"]
+    evidence: Annotated[dict | None, "Supporting context, quotes, etc."]
+    agent_model: Annotated[str, "Which AI model made this observation"]
+
+
class AgentObservation(SourceItem):
    """

@@ -652,18 +750,16 @@ class AgentObservation(SourceItem):
        kwargs["modality"] = "observation"
        super().__init__(**kwargs)

-    def as_payload(self) -> dict:
-        payload = {
+    def as_payload(self) -> AgentObservationPayload:
+        return AgentObservationPayload(
            **super().as_payload(),
-            "observation_type": self.observation_type,
-            "subject": self.subject,
-            "confidence": self.confidence_dict,
-            "evidence": self.evidence,
-            "agent_model": self.agent_model,
-        }
-        if self.session_id is not None:
-            payload["session_id"] = str(self.session_id)
-        return payload
+            observation_type=cast(str, self.observation_type),
+            subject=cast(str, self.subject),
+            confidence=self.confidence_dict,
+            evidence=cast(dict | None, self.evidence),
+            agent_model=cast(str, self.agent_model),
+            session_id=cast(str | None, self.session_id) and str(self.session_id),
+        )

    @property
    def all_contradictions(self):

@@ -759,3 +855,7 @@ class AgentObservation(SourceItem):
        # ))

        return chunks
+
+    @classmethod
+    def get_collections(cls) -> list[str]:
+        return ["semantic", "temporal"]
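Taken together, the new `get_collections` classmethods let each model decide which Qdrant collections its chunks land in, with the base class defaulting to the table name. A schematic (simplified, without SQLAlchemy) of how a caller might build a routing map from them:

class SourceItem:
    __tablename__ = "source_item"

    @classmethod
    def get_collections(cls) -> list[str]:
        return [cls.__tablename__]


class Note(SourceItem):
    __tablename__ = "note"

    @classmethod
    def get_collections(cls) -> list[str]:
        return ["text"]


class AgentObservation(SourceItem):
    __tablename__ = "agent_observation"

    @classmethod
    def get_collections(cls) -> list[str]:
        return ["semantic", "temporal"]


# collection -> model names, e.g. for deciding where to upsert chunks
routing: dict[str, list[str]] = {}
for model in (Note, AgentObservation):
    for collection in model.get_collections():
        routing.setdefault(collection, []).append(model.__name__)
print(routing)  # {'text': ['Note'], 'semantic': ['AgentObservation'], 'temporal': ['AgentObservation']}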
@@ -53,7 +53,9 @@ def page_to_image(page: pymupdf.Page) -> Image.Image:
    return image


-def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
+def doc_to_images(
+    content: bytes | str | pathlib.Path, modality: str = "doc"
+) -> list[DataChunk]:
    with as_file(content) as file_path:
        with pymupdf.open(file_path) as pdf:
            return [

@@ -65,6 +67,7 @@ def doc_to_images(content: bytes | str | pathlib.Path) -> list[DataChunk]:
                        "height": page.rect.height,
                    },
                    mime_type="image/jpeg",
+                    modality=modality,
                )
                for page in pdf.pages()
            ]

@@ -122,6 +125,7 @@ def extract_text(
    content: bytes | str | pathlib.Path,
    chunk_size: int | None = None,
    metadata: dict[str, Any] = {},
+    modality: str = "text",
) -> list[DataChunk]:
    if isinstance(content, pathlib.Path):
        content = content.read_text()

@@ -130,7 +134,7 @@ def extract_text(

    content = cast(str, content)
    chunks = [
-        DataChunk(data=[c], modality="text", metadata=metadata)
+        DataChunk(data=[c], modality=modality, metadata=metadata)
        for c in chunker.chunk_text(content, chunk_size or chunker.DEFAULT_CHUNK_TOKENS)
    ]
    if content and len(content) > chunker.DEFAULT_CHUNK_TOKENS * 2:

@@ -139,7 +143,7 @@ def extract_text(
            DataChunk(
                data=[summary],
                metadata=merge_metadata(metadata, {"tags": tags}),
-                modality="text",
+                modality=modality,
            )
        )
    return chunks

@@ -158,9 +162,7 @@ def extract_data_chunks(
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/msword",
    ]:
-        logger.info(f"Extracting content from {content}")
        chunks = extract_docx(content)
-        logger.info(f"Extracted {len(chunks)} pages from {content}")
    elif mime_type.startswith("text/"):
        chunks = extract_text(content, chunk_size)
    elif mime_type.startswith("image/"):
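With `modality` threaded through, callers can tag chunks per source type instead of everything defaulting to "text". Hypothetical call shapes (the import path is an assumption; the diff only shows the module used as `extract`):

import pathlib

from memory.common import extract  # assumed import path

# Text chunks inherit the caller's modality instead of a hard-coded "text".
chunks = extract.extract_text("Some email body ...", modality="mail")
print({c.modality for c in chunks})  # expected: {"mail"}

# PDF pages can be tagged the same way, defaulting to "doc".
pages = extract.doc_to_images(pathlib.Path("report.pdf"), modality="doc")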
@@ -224,6 +224,15 @@ def get_collection_info(
    return info.model_dump()


+def get_collection_sizes(client: qdrant_client.QdrantClient) -> dict[str, int]:
+    """Get the size of each collection."""
+    collections = [i.name for i in client.get_collections().collections]
+    return {
+        collection_name: client.count(collection_name).count  # type: ignore
+        for collection_name in collections
+    }
+
+
def batch_ids(
    client: qdrant_client.QdrantClient, collection_name: str, batch_size: int = 1000
) -> Generator[list[str], None, None]:
@@ -6,6 +6,8 @@ load_dotenv()


def boolean_env(key: str, default: bool = False) -> bool:
+    if key not in os.environ:
+        return default
    return os.getenv(key, "0").lower() in ("1", "true", "yes")

@@ -130,6 +132,10 @@ else:
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
SUMMARIZER_MODEL = os.getenv("SUMMARIZER_MODEL", "anthropic/claude-3-haiku-20240307")

+# Search settings
+ENABLE_EMBEDDING_SEARCH = boolean_env("ENABLE_EMBEDDING_SEARCH", True)
+ENABLE_BM25_SEARCH = boolean_env("ENABLE_BM25_SEARCH", True)
+
# API settings
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
HTTPS = boolean_env("HTTPS", False)
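The semantics of `boolean_env` above: an unset variable yields the default, while a set variable is truthy only if it is literally "1", "true" or "yes" (case-insensitive); anything else, including "no", parses as False:

import os

os.environ["ENABLE_BM25_SEARCH"] = "no"
print(boolean_env("ENABLE_BM25_SEARCH", True))   # False: "no" is not in ("1", "true", "yes")

del os.environ["ENABLE_BM25_SEARCH"]
print(boolean_env("ENABLE_BM25_SEARCH", True))   # True: unset falls back to the default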
@ -95,93 +95,93 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What does the user think about functional programming?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.7104,
|
||||
0.71,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(0.6788, "I prefer functional programming over OOP"),
|
||||
(0.679, "I prefer functional programming over OOP"),
|
||||
(
|
||||
0.6759,
|
||||
0.676,
|
||||
"Subject: programming_philosophy | Type: belief | Observation: The user believes functional programming leads to better code quality | Quote: Functional programming produces more maintainable code",
|
||||
),
|
||||
(
|
||||
0.6678,
|
||||
0.668,
|
||||
"Subject: programming_paradigms | Type: preference | Observation: The user prefers functional programming over OOP | Quote: I prefer functional programming over OOP",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5971,
|
||||
0.597,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.5308,
|
||||
0.531,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_paradigms | Observation: The user prefers functional programming over OOP",
|
||||
),
|
||||
(
|
||||
0.5167,
|
||||
0.517,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
|
||||
),
|
||||
(
|
||||
0.4702,
|
||||
0.47,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: refactoring | Observation: The user always refactors to pure functions",
|
||||
),
|
||||
],
|
||||
},
|
||||
"Does the user prefer functional or object-oriented programming?": {
|
||||
"semantic": [
|
||||
(0.7719, "The user prefers functional programming over OOP"),
|
||||
(0.772, "The user prefers functional programming over OOP"),
|
||||
(
|
||||
0.7541,
|
||||
0.754,
|
||||
"Subject: programming_paradigms | Type: preference | Observation: The user prefers functional programming over OOP | Quote: I prefer functional programming over OOP",
|
||||
),
|
||||
(0.7455, "I prefer functional programming over OOP"),
|
||||
(0.745, "I prefer functional programming over OOP"),
|
||||
(
|
||||
0.6536,
|
||||
0.654,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.6251,
|
||||
0.625,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_paradigms | Observation: The user prefers functional programming over OOP",
|
||||
),
|
||||
(
|
||||
0.6062,
|
||||
0.606,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.5061,
|
||||
0.506,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
|
||||
),
|
||||
(
|
||||
0.5036,
|
||||
0.504,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: refactoring | Observation: The user always refactors to pure functions",
|
||||
),
|
||||
],
|
||||
},
|
||||
"What are the user's beliefs about code quality?": {
|
||||
"semantic": [
|
||||
(0.6925, "The user believes code reviews are essential for quality"),
|
||||
(0.692, "The user believes code reviews are essential for quality"),
|
||||
(
|
||||
0.6801,
|
||||
0.68,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.6525,
|
||||
0.652,
|
||||
"Subject: code_quality | Type: belief | Observation: The user believes code reviews are essential for quality | Quote: Code reviews catch bugs that automated testing misses",
|
||||
),
|
||||
(
|
||||
0.6471,
|
||||
0.647,
|
||||
"Subject: programming_philosophy | Type: belief | Observation: The user believes functional programming leads to better code quality | Quote: Functional programming produces more maintainable code",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5269,
|
||||
0.527,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.5193,
|
||||
0.519,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: code_quality | Observation: The user believes code reviews are essential for quality",
|
||||
),
|
||||
(
|
||||
@ -189,7 +189,7 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: testing_philosophy | Observation: The user believes unit tests are a waste of time for prototypes",
|
||||
),
|
||||
(
|
||||
0.4377,
|
||||
0.438,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
|
||||
),
|
||||
],
|
||||
@ -197,22 +197,22 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"How does the user approach debugging code?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.7007,
|
||||
0.701,
|
||||
"Subject: debugging_approach | Type: behavior | Observation: The user debugs by adding print statements rather than using a debugger | Quote: When debugging, I just add console.log everywhere",
|
||||
),
|
||||
(
|
||||
0.6956,
|
||||
0.696,
|
||||
"The user debugs by adding print statements rather than using a debugger",
|
||||
),
|
||||
(0.6795, "When debugging, I just add console.log everywhere"),
|
||||
(0.68, "When debugging, I just add console.log everywhere"),
|
||||
(
|
||||
0.5352,
|
||||
0.535,
|
||||
"Subject: code_quality | Type: belief | Observation: The user believes code reviews are essential for quality | Quote: Code reviews catch bugs that automated testing misses",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.6253,
|
||||
0.625,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: debugging_approach | Observation: The user debugs by adding print statements rather than using a debugger",
|
||||
),
|
||||
(
|
||||
@ -220,11 +220,11 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: indentation_preference | Observation: The user claims to prefer tabs but their code uses spaces",
|
||||
),
|
||||
(
|
||||
0.4589,
|
||||
0.459,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: testing_philosophy | Observation: The user believes unit tests are a waste of time for prototypes",
|
||||
),
|
||||
(
|
||||
0.4502,
|
||||
0.45,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
|
||||
),
|
||||
],
|
||||
@ -232,63 +232,63 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What are the user's git and version control habits?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.6485,
|
||||
0.648,
|
||||
"Subject: version_control_style | Type: preference | Observation: The user prefers small, focused commits over large feature branches | Quote: I like to commit small, logical changes frequently",
|
||||
),
|
||||
(0.643, "I like to commit small, logical changes frequently"),
|
||||
(
|
||||
0.5968,
|
||||
0.597,
|
||||
"The user prefers small, focused commits over large feature branches",
|
||||
),
|
||||
(
|
||||
0.5813,
|
||||
0.581,
|
||||
"Subject: git_habits | Type: behavior | Observation: The user writes commit messages in present tense | Quote: Fix bug in parser instead of Fixed bug in parser",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.6063,
|
||||
0.606,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
|
||||
),
|
||||
(
|
||||
0.5569,
|
||||
0.557,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: git_habits | Observation: The user writes commit messages in present tense",
|
||||
),
|
||||
(
|
||||
0.4806,
|
||||
0.481,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: editor_preference | Observation: The user prefers Vim over VS Code for editing",
|
||||
),
|
||||
(
|
||||
0.4622,
|
||||
0.462,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: code_quality | Observation: The user believes code reviews are essential for quality",
|
||||
),
|
||||
],
|
||||
},
|
||||
"When does the user prefer to work?": {
|
||||
"semantic": [
|
||||
(0.6805, "The user prefers working late at night"),
|
||||
(0.681, "The user prefers working late at night"),
|
||||
(
|
||||
0.6794,
|
||||
0.679,
|
||||
"Subject: work_schedule | Type: behavior | Observation: The user prefers working late at night | Quote: I do my best coding between 10pm and 2am",
|
||||
),
|
||||
(0.6432, "I do my best coding between 10pm and 2am"),
|
||||
(0.5525, "I use 25-minute work intervals with 5-minute breaks"),
|
||||
(0.643, "I do my best coding between 10pm and 2am"),
|
||||
(0.553, "I use 25-minute work intervals with 5-minute breaks"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.6896,
|
||||
0.69,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: work_schedule | Observation: The user prefers working late at night",
|
||||
),
|
||||
(
|
||||
0.6327,
|
||||
0.633,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
||||
),
|
||||
(
|
||||
0.6266,
|
||||
0.627,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
|
||||
),
|
||||
(
|
||||
0.6206,
|
||||
0.621,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: collaboration_preference | Observation: The user prefers pair programming for complex problems",
|
||||
),
|
||||
],
|
||||
@ -296,31 +296,31 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"How does the user handle productivity and time management?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.5795,
|
||||
0.579,
|
||||
"Subject: productivity_methods | Type: behavior | Observation: The user takes breaks every 25 minutes using the Pomodoro technique | Quote: I use 25-minute work intervals with 5-minute breaks",
|
||||
),
|
||||
(0.5727, "I use 25-minute work intervals with 5-minute breaks"),
|
||||
(0.572, "I use 25-minute work intervals with 5-minute breaks"),
|
||||
(
|
||||
0.5282,
|
||||
0.527,
|
||||
"The user takes breaks every 25 minutes using the Pomodoro technique",
|
||||
),
|
||||
(0.515, "I do my best coding between 10pm and 2am"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5633,
|
||||
0.563,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: productivity_methods | Observation: The user takes breaks every 25 minutes using the Pomodoro technique",
|
||||
),
|
||||
(
|
||||
0.5105,
|
||||
0.51,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
|
||||
),
|
||||
(
|
||||
0.4737,
|
||||
0.473,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: documentation_habits | Observation: The user always writes documentation before implementing features",
|
||||
),
|
||||
(
|
||||
0.4672,
|
||||
0.467,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: work_schedule | Observation: The user prefers working late at night",
|
||||
),
|
||||
],
|
||||
@ -328,28 +328,28 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What editor does the user prefer?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.6398,
|
||||
0.64,
|
||||
"Subject: editor_preference | Type: preference | Observation: The user prefers Vim over VS Code for editing | Quote: Vim makes me more productive than any modern editor",
|
||||
),
|
||||
(0.6242, "The user prefers Vim over VS Code for editing"),
|
||||
(0.5524, "Vim makes me more productive than any modern editor"),
|
||||
(0.4887, "The user claims to prefer tabs but their code uses spaces"),
|
||||
(0.624, "The user prefers Vim over VS Code for editing"),
|
||||
(0.552, "Vim makes me more productive than any modern editor"),
|
||||
(0.489, "The user claims to prefer tabs but their code uses spaces"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5626,
|
||||
0.563,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: editor_preference | Observation: The user prefers Vim over VS Code for editing",
|
||||
),
|
||||
(
|
||||
0.4507,
|
||||
0.451,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: indentation_preference | Observation: The user claims to prefer tabs but their code uses spaces",
|
||||
),
|
||||
(
|
||||
0.4333,
|
||||
0.433,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: database_preference | Observation: The user prefers PostgreSQL over MongoDB for most applications",
|
||||
),
|
||||
(
|
||||
0.4307,
|
||||
0.431,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
||||
),
|
||||
],
|
||||
@ -357,27 +357,27 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What databases does the user like to use?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.6328,
|
||||
0.633,
|
||||
"Subject: database_preference | Type: preference | Observation: The user prefers PostgreSQL over MongoDB for most applications | Quote: Relational databases handle complex queries better than document stores",
|
||||
),
|
||||
(0.5991, "The user prefers PostgreSQL over MongoDB for most applications"),
|
||||
(0.599, "The user prefers PostgreSQL over MongoDB for most applications"),
|
||||
(
|
||||
0.5357,
|
||||
0.536,
|
||||
"Subject: domain_preference | Type: preference | Observation: The user prefers working on backend systems over frontend UI | Quote: I find backend logic more interesting than UI work",
|
||||
),
|
||||
(0.5178, "The user prefers working on backend systems over frontend UI"),
|
||||
(0.518, "The user prefers working on backend systems over frontend UI"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5503,
|
||||
0.55,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: database_preference | Observation: The user prefers PostgreSQL over MongoDB for most applications",
|
||||
),
|
||||
(
|
||||
0.4583,
|
||||
0.458,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
|
||||
),
|
||||
(
|
||||
0.4445,
|
||||
0.445,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
|
||||
),
|
||||
(
|
||||
@ -388,21 +388,21 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
},
|
||||
"What programming languages does the user work with?": {
|
||||
"semantic": [
|
||||
(0.7264, "The user primarily works with Python and JavaScript"),
|
||||
(0.6958, "Most of my work is in Python backend and React frontend"),
|
||||
(0.726, "The user primarily works with Python and JavaScript"),
|
||||
(0.696, "Most of my work is in Python backend and React frontend"),
|
||||
(
|
||||
0.6875,
|
||||
0.688,
|
||||
"Subject: primary_languages | Type: general | Observation: The user primarily works with Python and JavaScript | Quote: Most of my work is in Python backend and React frontend",
|
||||
),
|
||||
(0.6111, "I'm picking up Rust on weekends"),
|
||||
(0.611, "I'm picking up Rust on weekends"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5774,
|
||||
0.577,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
|
||||
),
|
||||
(
|
||||
0.4692,
|
||||
0.469,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: experience_level | Observation: The user has 8 years of professional programming experience",
|
||||
),
|
||||
(
|
||||
@ -410,36 +410,36 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.4475,
|
||||
0.447,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: learning_activities | Observation: The user is currently learning Rust in their spare time",
|
||||
),
|
||||
],
|
||||
},
|
||||
"What is the user's programming experience level?": {
|
||||
"semantic": [
|
||||
(0.6663, "The user has 8 years of professional programming experience"),
|
||||
(0.666, "The user has 8 years of professional programming experience"),
|
||||
(
|
||||
0.6562,
|
||||
0.656,
|
||||
"Subject: experience_level | Type: general | Observation: The user has 8 years of professional programming experience | Quote: I've been coding professionally for 8 years",
|
||||
),
|
||||
(0.5952, "I've been coding professionally for 8 years"),
|
||||
(0.5656, "The user is currently learning Rust in their spare time"),
|
||||
(0.595, "I've been coding professionally for 8 years"),
|
||||
(0.566, "The user is currently learning Rust in their spare time"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5808,
|
||||
0.581,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: experience_level | Observation: The user has 8 years of professional programming experience",
|
||||
),
|
||||
(
|
||||
0.4814,
|
||||
0.481,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
|
||||
),
|
||||
(
|
||||
0.4752,
|
||||
0.475,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.4591,
|
||||
0.459,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_paradigms | Observation: The user prefers functional programming over OOP",
|
||||
),
|
||||
],
|
||||
@ -447,57 +447,57 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"Where did the user study computer science?": {
|
||||
"semantic": [
|
||||
(0.686, "I studied CS at Stanford"),
|
||||
(0.6484, "The user graduated with a Computer Science degree from Stanford"),
|
||||
(0.648, "The user graduated with a Computer Science degree from Stanford"),
|
||||
(
|
||||
0.6346,
|
||||
0.635,
|
||||
"Subject: education_background | Type: general | Observation: The user graduated with a Computer Science degree from Stanford | Quote: I studied CS at Stanford",
|
||||
),
|
||||
(0.4599, "The user is currently learning Rust in their spare time"),
|
||||
(0.46, "The user is currently learning Rust in their spare time"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5288,
|
||||
0.529,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: education_background | Observation: The user graduated with a Computer Science degree from Stanford",
|
||||
),
|
||||
(
|
||||
0.3833,
|
||||
0.383,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: experience_level | Observation: The user has 8 years of professional programming experience",
|
||||
),
|
||||
(
|
||||
0.3728,
|
||||
0.373,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
|
||||
),
|
||||
(
|
||||
0.3651,
|
||||
0.365,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: learning_activities | Observation: The user is currently learning Rust in their spare time",
|
||||
),
|
||||
],
|
||||
},
|
||||
"What kind of company does the user work at?": {
|
||||
"semantic": [
|
||||
(0.6304, "The user works at a mid-size startup with 50 employees"),
|
||||
(0.63, "The user works at a mid-size startup with 50 employees"),
|
||||
(
|
||||
0.5369,
|
||||
0.537,
|
||||
"Subject: company_size | Type: general | Observation: The user works at a mid-size startup with 50 employees | Quote: Our company has about 50 people",
|
||||
),
|
||||
(0.5258, "Most of my work is in Python backend and React frontend"),
|
||||
(0.4905, "I've been coding professionally for 8 years"),
|
||||
(0.526, "Most of my work is in Python backend and React frontend"),
|
||||
(0.49, "I've been coding professionally for 8 years"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5194,
|
||||
0.519,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: company_size | Observation: The user works at a mid-size startup with 50 employees",
|
||||
),
|
||||
(
|
||||
0.4149,
|
||||
0.415,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
|
||||
),
|
||||
(
|
||||
0.4144,
|
||||
0.414,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: education_background | Observation: The user graduated with a Computer Science degree from Stanford",
|
||||
),
|
||||
(
|
||||
0.4053,
|
||||
0.405,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: experience_level | Observation: The user has 8 years of professional programming experience",
|
||||
),
|
||||
],
|
||||
@ -505,34 +505,34 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What does the user think about AI replacing programmers?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.5955,
|
||||
0.596,
|
||||
"Subject: ai_future | Type: belief | Observation: The user thinks AI will replace most software developers within 10 years | Quote: AI will make most programmers obsolete by 2035",
|
||||
),
|
||||
(0.5725, "AI will make most programmers obsolete by 2035"),
|
||||
(0.572, "AI will make most programmers obsolete by 2035"),
|
||||
(
|
||||
0.572,
|
||||
"The user thinks AI will replace most software developers within 10 years",
|
||||
),
|
||||
(
|
||||
0.4342,
|
||||
0.434,
|
||||
"The user believes functional programming leads to better code quality",
|
||||
),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.4546,
|
||||
0.455,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: ai_future | Observation: The user thinks AI will replace most software developers within 10 years",
|
||||
),
|
||||
(
|
||||
0.3583,
|
||||
0.358,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.3264,
|
||||
0.326,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: typescript_opinion | Observation: The user now says they love TypeScript but previously called it verbose",
|
||||
),
|
||||
(
|
||||
0.3257,
|
||||
0.326,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: testing_philosophy | Observation: The user believes unit tests are a waste of time for prototypes",
|
||||
),
|
||||
],
|
||||
@ -540,31 +540,31 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"What are the user's views on artificial intelligence?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.5884,
|
||||
0.588,
|
||||
"Subject: ai_future | Type: belief | Observation: The user thinks AI will replace most software developers within 10 years | Quote: AI will make most programmers obsolete by 2035",
|
||||
),
|
||||
(
|
||||
0.5659,
|
||||
0.566,
|
||||
"The user thinks AI will replace most software developers within 10 years",
|
||||
),
|
||||
(0.5139, "AI will make most programmers obsolete by 2035"),
|
||||
(0.4927, "I find backend logic more interesting than UI work"),
|
||||
(0.514, "AI will make most programmers obsolete by 2035"),
|
||||
(0.493, "I find backend logic more interesting than UI work"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5205,
|
||||
0.521,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: ai_future | Observation: The user thinks AI will replace most software developers within 10 years",
|
||||
),
|
||||
(
|
||||
0.4203,
|
||||
0.42,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
|
||||
),
|
||||
(
|
||||
0.4007,
|
||||
0.401,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
|
||||
),
|
||||
(
|
||||
0.4001,
|
||||
0.4,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: humans | Observation: The user thinks that all men must die.",
|
||||
),
|
||||
],
|
||||
@ -572,34 +572,34 @@ EXPECTED_OBSERVATION_RESULTS = {
|
||||
"Has the user changed their mind about TypeScript?": {
|
||||
"semantic": [
|
||||
(
|
||||
0.6166,
|
||||
0.617,
|
||||
"The user now says they love TypeScript but previously called it verbose",
|
||||
),
|
||||
(
|
||||
0.5764,
|
||||
0.576,
|
||||
"Subject: typescript_opinion | Type: contradiction | Observation: The user now says they love TypeScript but previously called it verbose | Quote: TypeScript has too much boilerplate vs TypeScript makes my code so much cleaner",
|
||||
),
|
||||
(
|
||||
0.4907,
|
||||
0.491,
|
||||
"TypeScript has too much boilerplate vs TypeScript makes my code so much cleaner",
|
||||
),
|
||||
(0.4159, "The user always refactors to pure functions"),
|
||||
(0.416, "The user always refactors to pure functions"),
|
||||
],
|
||||
"temporal": [
|
||||
(
|
||||
0.5663,
|
||||
0.566,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: typescript_opinion | Observation: The user now says they love TypeScript but previously called it verbose",
|
||||
),
|
||||
(
|
||||
0.3897,
|
||||
0.39,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: indentation_preference | Observation: The user claims to prefer tabs but their code uses spaces",
|
||||
),
|
||||
(
|
||||
0.3833,
|
||||
0.383,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
|
||||
),
|
||||
(
|
||||
0.3761,
|
||||
0.376,
|
||||
"Time: 12:00 on Wednesday (afternoon) | Subject: editor_preference | Observation: The user prefers Vim over VS Code for editing",
|
||||
),
|
||||
],
|
||||
@ -608,11 +608,11 @@ EXPECTED_OBSERVATION_RESULTS = {
        "semantic": [
            (0.536, "The user claims to prefer tabs but their code uses spaces"),
            (
-               0.5353,
+               0.535,
                "Subject: indentation_preference | Type: contradiction | Observation: The user claims to prefer tabs but their code uses spaces | Quote: Tabs are better than spaces vs code consistently uses 2-space indentation",
            ),
            (
-               0.5328,
+               0.533,
                "Subject: pure_functions | Type: contradiction | Observation: The user said pure functions are yucky | Quote: Pure functions are yucky",
            ),
            (
@ -622,19 +622,19 @@ EXPECTED_OBSERVATION_RESULTS = {
        ],
        "temporal": [
            (
-               0.4671,
+               0.467,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
            (
-               0.4661,
+               0.466,
                "Time: 12:00 on Wednesday (afternoon) | Subject: indentation_preference | Observation: The user claims to prefer tabs but their code uses spaces",
            ),
            (
-               0.4566,
+               0.457,
                "Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
            ),
            (
-               0.4553,
+               0.455,
                "Time: 12:00 on Wednesday (afternoon) | Subject: database_preference | Observation: The user prefers PostgreSQL over MongoDB for most applications",
            ),
        ],
@ -642,15 +642,15 @@ EXPECTED_OBSERVATION_RESULTS = {
    "What does the user think about software testing?": {
        "semantic": [
            (
-               0.6384,
+               0.638,
                "Subject: testing_philosophy | Type: belief | Observation: The user believes unit tests are a waste of time for prototypes | Quote: Writing tests for throwaway code slows development",
            ),
-           (0.6219, "The user believes unit tests are a waste of time for prototypes"),
+           (0.622, "The user believes unit tests are a waste of time for prototypes"),
            (
-               0.6154,
+               0.615,
                "Subject: code_quality | Type: belief | Observation: The user believes code reviews are essential for quality | Quote: Code reviews catch bugs that automated testing misses",
            ),
-           (0.6031, "The user believes code reviews are essential for quality"),
+           (0.603, "The user believes code reviews are essential for quality"),
        ],
        "temporal": [
            (
@ -658,15 +658,15 @@ EXPECTED_OBSERVATION_RESULTS = {
                "Time: 12:00 on Wednesday (afternoon) | Subject: testing_philosophy | Observation: The user believes unit tests are a waste of time for prototypes",
            ),
            (
-               0.4901,
+               0.49,
                "Time: 12:00 on Wednesday (afternoon) | Subject: code_quality | Observation: The user believes code reviews are essential for quality",
            ),
            (
-               0.4745,
+               0.474,
                "Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
            ),
            (
-               0.4524,
+               0.452,
                "Time: 12:00 on Wednesday (afternoon) | Subject: debugging_approach | Observation: The user debugs by adding print statements rather than using a debugger",
            ),
        ],
@ -678,30 +678,30 @@ EXPECTED_OBSERVATION_RESULTS = {
                "Subject: documentation_habits | Type: behavior | Observation: The user always writes documentation before implementing features | Quote: I document the API design before writing any code",
            ),
            (
-               0.5462,
+               0.546,
                "The user always writes documentation before implementing features",
            ),
-           (0.5213, "I document the API design before writing any code"),
+           (0.521, "I document the API design before writing any code"),
            (
-               0.4949,
+               0.495,
                "Subject: debugging_approach | Type: behavior | Observation: The user debugs by adding print statements rather than using a debugger | Quote: When debugging, I just add console.log everywhere",
            ),
        ],
        "temporal": [
            (
-               0.5001,
+               0.5,
                "Time: 12:00 on Wednesday (afternoon) | Subject: documentation_habits | Observation: The user always writes documentation before implementing features",
            ),
            (
-               0.4371,
+               0.437,
                "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
            ),
            (
-               0.4355,
+               0.435,
                "Time: 12:00 on Wednesday (afternoon) | Subject: indentation_preference | Observation: The user claims to prefer tabs but their code uses spaces",
            ),
            (
-               0.4347,
+               0.435,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
        ],
@ -709,12 +709,12 @@ EXPECTED_OBSERVATION_RESULTS = {
    "What are the user's collaboration preferences?": {
        "semantic": [
            (
-               0.6516,
+               0.652,
                "Subject: collaboration_preference | Type: preference | Observation: The user prefers pair programming for complex problems | Quote: Two heads are better than one when solving hard problems",
            ),
-           (0.5855, "The user prefers pair programming for complex problems"),
+           (0.585, "The user prefers pair programming for complex problems"),
            (
-               0.5361,
+               0.536,
                "Subject: version_control_style | Type: preference | Observation: The user prefers small, focused commits over large feature branches | Quote: I like to commit small, logical changes frequently",
            ),
            (
@ -724,7 +724,7 @@ EXPECTED_OBSERVATION_RESULTS = {
        ],
        "temporal": [
            (
-               0.5889,
+               0.589,
                "Time: 12:00 on Wednesday (afternoon) | Subject: collaboration_preference | Observation: The user prefers pair programming for complex problems",
            ),
            (
@ -732,40 +732,40 @@ EXPECTED_OBSERVATION_RESULTS = {
                "Time: 12:00 on Wednesday (afternoon) | Subject: version_control_style | Observation: The user prefers small, focused commits over large feature branches",
            ),
            (
-               0.4754,
+               0.475,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
            (
-               0.4638,
+               0.464,
                "Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
            ),
        ],
    },
    "What does the user think about remote work?": {
        "semantic": [
-           (0.7054, "The user thinks remote work is more productive than office work"),
+           (0.705, "The user thinks remote work is more productive than office work"),
            (
-               0.6581,
+               0.658,
                "Subject: work_environment | Type: belief | Observation: The user thinks remote work is more productive than office work | Quote: I get more done working from home",
            ),
-           (0.6026, "I get more done working from home"),
-           (0.4991, "The user prefers working on backend systems over frontend UI"),
+           (0.603, "I get more done working from home"),
+           (0.499, "The user prefers working on backend systems over frontend UI"),
        ],
        "temporal": [
            (
-               0.5832,
+               0.583,
                "Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
            ),
            (
-               0.4126,
+               0.413,
                "Time: 12:00 on Wednesday (afternoon) | Subject: testing_philosophy | Observation: The user believes unit tests are a waste of time for prototypes",
            ),
            (
-               0.4122,
+               0.412,
                "Time: 12:00 on Wednesday (afternoon) | Subject: collaboration_preference | Observation: The user prefers pair programming for complex problems",
            ),
            (
-               0.4092,
+               0.409,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
        ],
@ -773,27 +773,27 @@ EXPECTED_OBSERVATION_RESULTS = {
    "What are the user's productivity methods?": {
        "semantic": [
            (
-               0.5729,
+               0.573,
                "Subject: productivity_methods | Type: behavior | Observation: The user takes breaks every 25 minutes using the Pomodoro technique | Quote: I use 25-minute work intervals with 5-minute breaks",
            ),
            (
-               0.5261,
+               0.526,
                "The user takes breaks every 25 minutes using the Pomodoro technique",
            ),
-           (0.5205, "I use 25-minute work intervals with 5-minute breaks"),
+           (0.52, "I use 25-minute work intervals with 5-minute breaks"),
            (0.512, "The user thinks remote work is more productive than office work"),
        ],
        "temporal": [
            (
-               0.5312,
+               0.531,
                "Time: 12:00 on Wednesday (afternoon) | Subject: productivity_methods | Observation: The user takes breaks every 25 minutes using the Pomodoro technique",
            ),
            (
-               0.4796,
+               0.48,
                "Time: 12:00 on Wednesday (afternoon) | Subject: work_environment | Observation: The user thinks remote work is more productive than office work",
            ),
            (
-               0.4344,
+               0.434,
                "Time: 12:00 on Wednesday (afternoon) | Subject: collaboration_preference | Observation: The user prefers pair programming for complex problems",
            ),
            (
@ -804,17 +804,17 @@ EXPECTED_OBSERVATION_RESULTS = {
    },
    "What technical skills is the user learning?": {
        "semantic": [
-           (0.5766, "The user is currently learning Rust in their spare time"),
+           (0.577, "The user is currently learning Rust in their spare time"),
            (
                0.55,
                "Subject: learning_activities | Type: general | Observation: The user is currently learning Rust in their spare time | Quote: I'm picking up Rust on weekends",
            ),
-           (0.5415, "I'm picking up Rust on weekends"),
-           (0.5156, "The user primarily works with Python and JavaScript"),
+           (0.542, "I'm picking up Rust on weekends"),
+           (0.516, "The user primarily works with Python and JavaScript"),
        ],
        "temporal": [
            (
-               0.5221,
+               0.522,
                "Time: 12:00 on Wednesday (afternoon) | Subject: learning_activities | Observation: The user is currently learning Rust in their spare time",
            ),
            (
@ -822,32 +822,32 @@ EXPECTED_OBSERVATION_RESULTS = {
                "Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
            ),
            (
-               0.4871,
+               0.487,
                "Time: 12:00 on Wednesday (afternoon) | Subject: experience_level | Observation: The user has 8 years of professional programming experience",
            ),
            (
-               0.4547,
+               0.455,
                "Time: 12:00 on Wednesday (afternoon) | Subject: education_background | Observation: The user graduated with a Computer Science degree from Stanford",
            ),
        ],
    },
    "What does the user think about cooking?": {
        "semantic": [
-           (0.4893, "I find backend logic more interesting than UI work"),
-           (0.4621, "The user prefers working on backend systems over frontend UI"),
+           (0.489, "I find backend logic more interesting than UI work"),
+           (0.462, "The user prefers working on backend systems over frontend UI"),
            (
-               0.4551,
+               0.455,
                "The user believes functional programming leads to better code quality",
            ),
-           (0.4549, "The user said pure functions are yucky"),
+           (0.455, "The user said pure functions are yucky"),
        ],
        "temporal": [
            (
-               0.3785,
+               0.379,
                "Time: 12:00 on Wednesday (afternoon) | Subject: pure_functions | Observation: The user said pure functions are yucky",
            ),
            (
-               0.3759,
+               0.376,
                "Time: 12:00 on Wednesday (afternoon) | Subject: programming_philosophy | Observation: The user believes functional programming leads to better code quality",
            ),
            (
@ -855,7 +855,7 @@ EXPECTED_OBSERVATION_RESULTS = {
                "Time: 12:00 on Wednesday (afternoon) | Subject: typescript_opinion | Observation: The user now says they love TypeScript but previously called it verbose",
            ),
            (
-               0.3594,
+               0.359,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
        ],
@ -866,25 +866,25 @@ EXPECTED_OBSERVATION_RESULTS = {
                0.523,
                "Subject: domain_preference | Type: preference | Observation: The user prefers working on backend systems over frontend UI | Quote: I find backend logic more interesting than UI work",
            ),
-           (0.5143, "The user prefers functional programming over OOP"),
-           (0.5074, "The user prefers working on backend systems over frontend UI"),
-           (0.5049, "The user prefers working late at night"),
+           (0.514, "The user prefers functional programming over OOP"),
+           (0.507, "The user prefers working on backend systems over frontend UI"),
+           (0.505, "The user prefers working late at night"),
        ],
        "temporal": [
            (
-               0.4767,
+               0.477,
                "Time: 12:00 on Wednesday (afternoon) | Subject: domain_preference | Observation: The user prefers working on backend systems over frontend UI",
            ),
            (
-               0.4748,
+               0.475,
                "Time: 12:00 on Wednesday (afternoon) | Subject: database_preference | Observation: The user prefers PostgreSQL over MongoDB for most applications",
            ),
            (
-               0.4587,
+               0.459,
                "Time: 12:00 on Wednesday (afternoon) | Subject: programming_paradigms | Observation: The user prefers functional programming over OOP",
            ),
            (
-               0.4554,
+               0.455,
                "Time: 12:00 on Wednesday (afternoon) | Subject: collaboration_preference | Observation: The user prefers pair programming for complex problems",
            ),
        ],
@ -892,28 +892,28 @@ EXPECTED_OBSERVATION_RESULTS = {
    "What music does the user like?": {
        "semantic": [
            (
-               0.4933,
+               0.493,
                "Subject: domain_preference | Type: preference | Observation: The user prefers working on backend systems over frontend UI | Quote: I find backend logic more interesting than UI work",
            ),
-           (0.4906, "The user prefers working late at night"),
-           (0.4902, "The user prefers functional programming over OOP"),
-           (0.4894, "The user primarily works with Python and JavaScript"),
+           (0.491, "The user prefers working late at night"),
+           (0.49, "The user prefers functional programming over OOP"),
+           (0.489, "The user primarily works with Python and JavaScript"),
        ],
        "temporal": [
            (
-               0.4676,
+               0.468,
                "Time: 12:00 on Wednesday (afternoon) | Subject: typescript_opinion | Observation: The user now says they love TypeScript but previously called it verbose",
            ),
            (
-               0.4561,
+               0.456,
                "Time: 12:00 on Wednesday (afternoon) | Subject: primary_languages | Observation: The user primarily works with Python and JavaScript",
            ),
            (
-               0.4471,
+               0.447,
                "Time: 12:00 on Wednesday (afternoon) | Subject: programming_paradigms | Observation: The user prefers functional programming over OOP",
            ),
            (
-               0.4432,
+               0.443,
                "Time: 12:00 on Wednesday (afternoon) | Subject: editor_preference | Observation: The user prefers Vim over VS Code for editing",
            ),
        ],
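Each fixture entry above pins the top-4 (score, text) pairs a query should return, once from the "semantic" search over raw chunk text and once from the "temporal" search, whose chunks carry a "Time: ... | Subject: ..." prefix. A minimal sketch of the fixture shape, with values invented purely for illustration:

# Hypothetical entry illustrating the fixture shape (values invented):
EXPECTED_OBSERVATION_RESULTS = {
    "Does the user like X?": {
        "semantic": [(0.7, "The user likes X"), (0.65, "Subject: x | Observation: ...")],
        "temporal": [(0.6, "Time: 12:00 on Wednesday (afternoon) | Subject: x | Observation: ...")],
    },
}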
@ -1104,15 +1104,20 @@ def test_real_observation_embeddings(real_voyage_client, qdrant):
    def get_top(vector, search_type: str) -> list[tuple[float, str]]:
        results = qdrant_tools.search_vectors(qdrant, search_type, vector)
        return [
-           (round(i.score, 4), chunk_map[str(i.id)].content)
+           (pytest.approx(i.score, 0.1), chunk_map[str(i.id)].content)  # type: ignore
            for i in sorted(results, key=lambda x: x.score, reverse=True)
        ][:4]

+   results = {}
    for query, expected in EXPECTED_OBSERVATION_RESULTS.items():
        search_vector = embed_text(
            [extract.DataChunk(data=[query])], input_type="query"
        )[0]
        semantic_results = get_top(search_vector, "semantic")
        temporal_results = get_top(search_vector, "temporal")
+       results[query] = {
+           "semantic": semantic_results,
+           "temporal": temporal_results,
+       }
        assert semantic_results == expected["semantic"]
        assert temporal_results == expected["temporal"]
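The substantive change in this hunk is the comparison: exact round(score, 4) equality is replaced by pytest.approx at 10% relative tolerance, so small drift in embedding scores no longer fails the test (which is also why the expected values above were relaxed to three decimals). A standalone illustration of that comparison, assuming only pytest:

import pytest

# approx's second positional argument is the relative tolerance (rel=0.1),
# so an observed 0.4203 matches an expected 0.42 to within 10%.
assert 0.4203 == pytest.approx(0.42, 0.1)

# Tuples compare element-wise, so (approx(score), text) pairs can be asserted
# directly against the (expected_score, expected_text) fixture entries.
assert (pytest.approx(0.4203, 0.1), "chunk text") == (0.42, "chunk text")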
tests/memory/api/search/test_search_embeddings.py (new file, 174 lines)
@ -0,0 +1,174 @@
import pytest
from memory.api.search.embeddings import merge_range_filter, merge_filters


def test_merge_range_filter_new_filter():
    """Test adding new range filters"""
    filters = []
    result = merge_range_filter(filters, "min_size", 100)
    assert result == [{"key": "size", "range": {"gte": 100}}]

    filters = []
    result = merge_range_filter(filters, "max_size", 1000)
    assert result == [{"key": "size", "range": {"lte": 1000}}]


def test_merge_range_filter_existing_field():
    """Test adding to existing field"""
    filters = [{"key": "size", "range": {"lte": 1000}}]
    result = merge_range_filter(filters, "min_size", 100)
    assert result == [{"key": "size", "range": {"lte": 1000, "gte": 100}}]


def test_merge_range_filter_override_existing():
    """Test overriding existing values"""
    filters = [{"key": "size", "range": {"gte": 100}}]
    result = merge_range_filter(filters, "min_size", 200)
    assert result == [{"key": "size", "range": {"gte": 200}}]


def test_merge_range_filter_with_other_filters():
    """Test adding range filter alongside other filters"""
    filters = [{"key": "tags", "match": {"any": ["tag1"]}}]
    result = merge_range_filter(filters, "min_size", 100)

    expected = [
        {"key": "tags", "match": {"any": ["tag1"]}},
        {"key": "size", "range": {"gte": 100}},
    ]
    assert result == expected


@pytest.mark.parametrize(
    "key,expected_direction,expected_field",
    [
        ("min_sent_at", "min", "sent_at"),
        ("max_sent_at", "max", "sent_at"),
        ("min_published", "min", "published"),
        ("max_published", "max", "published"),
        ("min_size", "min", "size"),
        ("max_size", "max", "size"),
    ],
)
def test_merge_range_filter_key_parsing(key, expected_direction, expected_field):
    """Test that field names are correctly extracted from keys"""
    filters = []
    result = merge_range_filter(filters, key, 100)

    assert len(result) == 1
    assert result[0]["key"] == expected_field
    range_key = "gte" if expected_direction == "min" else "lte"
    assert result[0]["range"][range_key] == 100
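Taken together, these tests pin down merge_range_filter's contract: a min_*/max_* key splits into a direction and a payload field, and min/max conditions on the same field merge into a single range filter. A minimal sketch of an implementation that satisfies them (an inference from the tests, not necessarily the repo's actual code):

def merge_range_filter(filters: list[dict], key: str, value) -> list[dict]:
    # "min_sent_at" -> ("min", "sent_at"); "max_size" -> ("max", "size")
    direction, field = key.split("_", 1)
    op = "gte" if direction == "min" else "lte"
    for existing in filters:
        if existing["key"] == field and "range" in existing:
            existing["range"][op] = value  # merge into the existing range condition
            return filters
    return filters + [{"key": field, "range": {op: value}}]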
@pytest.mark.parametrize(
    "filter_key,filter_value",
    [
        ("tags", ["tag1", "tag2"]),
        ("recipients", ["user1", "user2"]),
        ("observation_types", ["belief", "preference"]),
        ("authors", ["author1"]),
    ],
)
def test_merge_filters_list_filters(filter_key, filter_value):
    """Test list filters that use match any"""
    filters = []
    result = merge_filters(filters, filter_key, filter_value)
    expected = [{"key": filter_key, "match": {"any": filter_value}}]
    assert result == expected


def test_merge_filters_min_confidences():
    """Test min_confidences filter creates multiple range conditions"""
    filters = []
    confidences = {"observation_accuracy": 0.8, "source_reliability": 0.9}
    result = merge_filters(filters, "min_confidences", confidences)

    expected = [
        {"key": "confidence.observation_accuracy", "range": {"gte": 0.8}},
        {"key": "confidence.source_reliability", "range": {"gte": 0.9}},
    ]
    assert result == expected


def test_merge_filters_source_ids():
    """Test source_ids filter maps to id field"""
    filters = []
    result = merge_filters(filters, "source_ids", ["id1", "id2"])
    expected = [{"key": "id", "match": {"any": ["id1", "id2"]}}]
    assert result == expected


def test_merge_filters_range_delegation():
    """Test range filters are properly delegated to merge_range_filter"""
    filters = []
    result = merge_filters(filters, "min_size", 100)

    assert len(result) == 1
    assert "range" in result[0]
    assert result[0]["range"]["gte"] == 100


def test_merge_filters_combined_range():
    """Test that min/max range pairs merge into single filter"""
    filters = []
    filters = merge_filters(filters, "min_size", 100)
    filters = merge_filters(filters, "max_size", 1000)

    size_filters = [f for f in filters if f["key"] == "size"]
    assert len(size_filters) == 1
    assert size_filters[0]["range"]["gte"] == 100
    assert size_filters[0]["range"]["lte"] == 1000


def test_merge_filters_preserves_existing():
    """Test that existing filters are preserved when adding new ones"""
    existing_filters = [{"key": "existing", "match": "value"}]
    result = merge_filters(existing_filters, "tags", ["new_tag"])

    assert len(result) == 2
    assert {"key": "existing", "match": "value"} in result
    assert {"key": "tags", "match": {"any": ["new_tag"]}} in result


def test_merge_filters_realistic_combination():
    """Test a realistic filter combination for knowledge base search"""
    filters = []

    # Add typical knowledge base filters
    filters = merge_filters(filters, "tags", ["important", "work"])
    filters = merge_filters(filters, "min_published", "2023-01-01")
    filters = merge_filters(filters, "max_size", 1000000)  # 1MB max
    filters = merge_filters(filters, "min_confidences", {"observation_accuracy": 0.8})

    assert len(filters) == 4

    # Check each filter type
    tag_filter = next(f for f in filters if f["key"] == "tags")
    assert tag_filter["match"]["any"] == ["important", "work"]

    published_filter = next(f for f in filters if f["key"] == "published")
    assert published_filter["range"]["gte"] == "2023-01-01"

    size_filter = next(f for f in filters if f["key"] == "size")
    assert size_filter["range"]["lte"] == 1000000

    confidence_filter = next(
        f for f in filters if f["key"] == "confidence.observation_accuracy"
    )
    assert confidence_filter["range"]["gte"] == 0.8


def test_merge_filters_unknown_key():
    """Test fallback behavior for unknown filter keys"""
    filters = []
    result = merge_filters(filters, "unknown_field", "unknown_value")
    expected = [{"key": "unknown_field", "match": "unknown_value"}]
    assert result == expected


def test_merge_filters_empty_min_confidences():
    """Test min_confidences with empty dict does nothing"""
    filters = []
    result = merge_filters(filters, "min_confidences", {})
    assert result == []
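Read as a whole, these tests imply merge_filters is a small dispatcher over the filter key. A sketch consistent with every assertion above (again an inference, not the repo's actual implementation):

LIST_FILTERS = {"tags", "recipients", "observation_types", "authors"}

def merge_filters(filters: list[dict], key: str, value) -> list[dict]:
    if key in LIST_FILTERS:
        return filters + [{"key": key, "match": {"any": value}}]
    if key == "source_ids":
        return filters + [{"key": "id", "match": {"any": value}}]
    if key == "min_confidences":
        return filters + [
            {"key": f"confidence.{name}", "range": {"gte": score}}
            for name, score in value.items()
        ]
    if key.startswith(("min_", "max_")):
        return merge_range_filter(filters, key, value)
    return filters + [{"key": key, "match": value}]  # fallback for unknown keys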
@ -198,6 +198,14 @@ def test_email_attachment_embeddings_text(mock_voyage_client):
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
+       mail_message=MailMessage(
+           sent_at=datetime(2025, 1, 1, 12, 0, 0),
+           message_id="123",
+           subject="Test",
+           sender="john.doe@techcorp.com",
+           recipients=["john.doe@techcorp.com"],
+           folder="INBOX",
+       ),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
@ -238,6 +246,14 @@ def test_email_attachment_embeddings_photo(mock_voyage_client):
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
+       mail_message=MailMessage(
+           sent_at=datetime(2025, 1, 1, 12, 0, 0),
+           message_id="123",
+           subject="Test",
+           sender="john.doe@techcorp.com",
+           recipients=["john.doe@techcorp.com"],
+           folder="INBOX",
+       ),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
@ -275,6 +291,14 @@ def test_email_attachment_embeddings_pdf(mock_voyage_client):
        sha256=hashlib.sha256(SAMPLE_MARKDOWN.encode("utf-8")).hexdigest(),
        size=len(SAMPLE_MARKDOWN),
        tags=["bla"],
+       mail_message=MailMessage(
+           sent_at=datetime(2025, 1, 1, 12, 0, 0),
+           message_id="123",
+           subject="Test",
+           sender="john.doe@techcorp.com",
+           recipients=["john.doe@techcorp.com"],
+           folder="INBOX",
+       ),
    )
    metadata = item.as_payload()
    metadata["tags"] = {"bla"}
@ -314,7 +338,7 @@ def test_email_attachment_embeddings_pdf(mock_voyage_client):
    ] == [page for _, page, _ in expected]


-def test_email_attachment_embeddings_comic(mock_voyage_client):
+def test_embeddings_comic(mock_voyage_client):
    item = Comic(
        id=1,
        content=SAMPLE_MARKDOWN,
@ -223,6 +223,14 @@ def test_email_attachment_as_payload(created_at, expected_date):
        mail_message_id=123,
        created_at=created_at,
        tags=["pdf", "document"],
+       mail_message=MailMessage(
+           sent_at=datetime(2025, 1, 1, 12, 0, 0),
+           message_id="123",
+           subject="Test",
+           sender="john.doe@techcorp.com",
+           recipients=["john.doe@techcorp.com"],
+           folder="INBOX",
+       ),
    )
    # Manually set id for testing
    object.__setattr__(attachment, "id", 456)
@ -237,6 +245,7 @@ def test_email_attachment_as_payload(created_at, expected_date):
        "created_at": expected_date,
        "mail_message_id": 123,
        "tags": ["pdf", "document"],
+       "sent_at": "2025-01-01T12:00:00",
    }
    assert payload == expected
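The new expected key comes from the attachment's parent message: sent_at is the MailMessage's datetime serialized to ISO format. A plausible sketch of the payload change this test exercises (an assumption based on the expected dict, not the repo's actual code):

from datetime import datetime

def with_sent_at(payload: dict, sent_at: datetime | None) -> dict:
    # Hypothetical helper: the attachment payload inherits sent_at from its
    # parent MailMessage, serialized with datetime.isoformat().
    if sent_at is not None:
        payload = {**payload, "sent_at": sent_at.isoformat()}
    return payload

assert with_sent_at({}, datetime(2025, 1, 1, 12, 0, 0)) == {"sent_at": "2025-01-01T12:00:00"}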
@ -219,6 +219,9 @@ def test_sync_comic_success(mock_get, mock_image_response, db_session, qdrant):
        "url": "https://example.com/comic/1",
        "source_id": 1,
        "size": 90,
+       "issue": None,
+       "volume": None,
+       "page": None,
    },
    None,
)