mirror of
https://github.com/mruwnik/memory.git
synced 2026-01-02 09:12:58 +01:00
exclude stuff
This commit is contained in:
parent
59c45ff1fb
commit
9cf71c9336
@ -0,0 +1,35 @@
|
|||||||
|
"""Add exclude_folder_ids to google_folders
|
||||||
|
|
||||||
|
Revision ID: add_exclude_folder_ids
|
||||||
|
Revises: 20251229_120000_add_google_drive
|
||||||
|
Create Date: 2025-12-29 15:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "add_exclude_folder_ids"
|
||||||
|
down_revision: Union[str, None] = "f2a3b4c5d6e7"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"google_folders",
|
||||||
|
sa.Column(
|
||||||
|
"exclude_folder_ids",
|
||||||
|
postgresql.ARRAY(sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="{}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("google_folders", "exclude_folder_ids")
|
||||||
@ -1935,4 +1935,98 @@ a.folder-item-name:hover {
|
|||||||
|
|
||||||
.add-btn.secondary:hover {
|
.add-btn.secondary:hover {
|
||||||
background: #edf2f7;
|
background: #edf2f7;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Exclusion Browser === */
|
||||||
|
|
||||||
|
.exclusion-browser {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
min-height: 400px;
|
||||||
|
max-height: 60vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
.excluded-list {
|
||||||
|
background: #fffbeb;
|
||||||
|
border: 1px solid #f59e0b;
|
||||||
|
border-radius: 6px;
|
||||||
|
padding: 0.75rem;
|
||||||
|
margin-bottom: 0.75rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.excluded-list h5 {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
color: #92400e;
|
||||||
|
margin-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.excluded-items {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.excluded-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.25rem;
|
||||||
|
background: #fef3c7;
|
||||||
|
padding: 0.25rem 0.5rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.excluded-name {
|
||||||
|
color: #92400e;
|
||||||
|
}
|
||||||
|
|
||||||
|
.remove-exclusion-btn {
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
color: #b45309;
|
||||||
|
cursor: pointer;
|
||||||
|
padding: 0 0.25rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
line-height: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.remove-exclusion-btn:hover {
|
||||||
|
color: #dc2626;
|
||||||
|
}
|
||||||
|
|
||||||
|
.folder-item.excluded {
|
||||||
|
background: #fffbeb;
|
||||||
|
border-color: #f59e0b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.exclusion-badge {
|
||||||
|
background: #f59e0b;
|
||||||
|
color: white;
|
||||||
|
padding: 0.125rem 0.375rem;
|
||||||
|
border-radius: 3px;
|
||||||
|
font-size: 0.7rem;
|
||||||
|
margin-left: auto;
|
||||||
|
margin-right: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.exclusions-btn {
|
||||||
|
background: #fffbeb;
|
||||||
|
color: #b45309;
|
||||||
|
border: 1px solid #f59e0b;
|
||||||
|
padding: 0.25rem 0.5rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.15s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.exclusions-btn:hover {
|
||||||
|
background: #fef3c7;
|
||||||
|
color: #92400e;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-badge.warning {
|
||||||
|
background: #fffbeb;
|
||||||
|
color: #b45309;
|
||||||
|
border: 1px solid #f59e0b;
|
||||||
}
|
}
|
||||||
@ -1,6 +1,6 @@
|
|||||||
import { useState, useEffect, useCallback } from 'react'
|
import { useState, useEffect, useCallback } from 'react'
|
||||||
import { Link } from 'react-router-dom'
|
import { Link } from 'react-router-dom'
|
||||||
import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources'
|
import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleFolder, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources'
|
||||||
import {
|
import {
|
||||||
SourceCard,
|
SourceCard,
|
||||||
Modal,
|
Modal,
|
||||||
@ -956,6 +956,7 @@ const GoogleDrivePanel = () => {
|
|||||||
const [error, setError] = useState<string | null>(null)
|
const [error, setError] = useState<string | null>(null)
|
||||||
const [addingFolderTo, setAddingFolderTo] = useState<number | null>(null)
|
const [addingFolderTo, setAddingFolderTo] = useState<number | null>(null)
|
||||||
const [browsingFoldersFor, setBrowsingFoldersFor] = useState<number | null>(null)
|
const [browsingFoldersFor, setBrowsingFoldersFor] = useState<number | null>(null)
|
||||||
|
const [managingExclusionsFor, setManagingExclusionsFor] = useState<{ accountId: number; folder: GoogleFolder } | null>(null)
|
||||||
const [uploadingConfig, setUploadingConfig] = useState(false)
|
const [uploadingConfig, setUploadingConfig] = useState(false)
|
||||||
|
|
||||||
const loadData = useCallback(async () => {
|
const loadData = useCallback(async () => {
|
||||||
@ -1044,6 +1045,12 @@ const GoogleDrivePanel = () => {
|
|||||||
loadData()
|
loadData()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleUpdateExclusions = async (accountId: number, folderId: number, excludeIds: string[]) => {
|
||||||
|
await updateGoogleFolder(accountId, folderId, { exclude_folder_ids: excludeIds })
|
||||||
|
setManagingExclusionsFor(null)
|
||||||
|
loadData()
|
||||||
|
}
|
||||||
|
|
||||||
const handleDeleteFolder = async (accountId: number, folderId: number) => {
|
const handleDeleteFolder = async (accountId: number, folderId: number) => {
|
||||||
await deleteGoogleFolder(accountId, folderId)
|
await deleteGoogleFolder(accountId, folderId)
|
||||||
loadData()
|
loadData()
|
||||||
@ -1188,6 +1195,11 @@ const GoogleDrivePanel = () => {
|
|||||||
<div className="folder-settings">
|
<div className="folder-settings">
|
||||||
{folder.recursive && <span className="setting-badge">Recursive</span>}
|
{folder.recursive && <span className="setting-badge">Recursive</span>}
|
||||||
{folder.include_shared && <span className="setting-badge">Shared</span>}
|
{folder.include_shared && <span className="setting-badge">Shared</span>}
|
||||||
|
{folder.exclude_folder_ids.length > 0 && (
|
||||||
|
<span className="setting-badge warning">
|
||||||
|
{folder.exclude_folder_ids.length} excluded
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="folder-actions">
|
<div className="folder-actions">
|
||||||
<SyncButton
|
<SyncButton
|
||||||
@ -1195,6 +1207,14 @@ const GoogleDrivePanel = () => {
|
|||||||
disabled={!folder.active || !account.active}
|
disabled={!folder.active || !account.active}
|
||||||
label="Sync"
|
label="Sync"
|
||||||
/>
|
/>
|
||||||
|
{folder.recursive && (
|
||||||
|
<button
|
||||||
|
className="exclusions-btn"
|
||||||
|
onClick={() => setManagingExclusionsFor({ accountId: account.id, folder })}
|
||||||
|
>
|
||||||
|
Exclusions
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
<button
|
<button
|
||||||
className="toggle-btn"
|
className="toggle-btn"
|
||||||
onClick={() => handleToggleFolderActive(account.id, folder.id, folder.active)}
|
onClick={() => handleToggleFolderActive(account.id, folder.id, folder.active)}
|
||||||
@ -1233,6 +1253,15 @@ const GoogleDrivePanel = () => {
|
|||||||
onCancel={() => setBrowsingFoldersFor(null)}
|
onCancel={() => setBrowsingFoldersFor(null)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{managingExclusionsFor && (
|
||||||
|
<ExclusionBrowser
|
||||||
|
accountId={managingExclusionsFor.accountId}
|
||||||
|
folder={managingExclusionsFor.folder}
|
||||||
|
onSave={(excludeIds) => handleUpdateExclusions(managingExclusionsFor.accountId, managingExclusionsFor.folder.id, excludeIds)}
|
||||||
|
onCancel={() => setManagingExclusionsFor(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -1422,6 +1451,195 @@ const FolderBrowser = ({ accountId, onSelect, onCancel }: FolderBrowserProps) =>
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === Exclusion Browser ===
|
||||||
|
|
||||||
|
interface ExclusionBrowserProps {
|
||||||
|
accountId: number
|
||||||
|
folder: GoogleFolder
|
||||||
|
onSave: (excludeIds: string[]) => void
|
||||||
|
onCancel: () => void
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ExcludedFolder {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
path: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const ExclusionBrowser = ({ accountId, folder, onSave, onCancel }: ExclusionBrowserProps) => {
|
||||||
|
const { browseGoogleDrive } = useSources()
|
||||||
|
const [path, setPath] = useState<PathItem[]>([{ id: folder.folder_id, name: folder.folder_name }])
|
||||||
|
const [items, setItems] = useState<DriveItem[]>([])
|
||||||
|
const [excluded, setExcluded] = useState<Map<string, ExcludedFolder>>(() => {
|
||||||
|
// Initialize with current exclusions (we don't have names, so use ID as name)
|
||||||
|
const map = new Map<string, ExcludedFolder>()
|
||||||
|
for (const id of folder.exclude_folder_ids) {
|
||||||
|
map.set(id, { id, name: id, path: '(previously excluded)' })
|
||||||
|
}
|
||||||
|
return map
|
||||||
|
})
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const currentFolderId = path[path.length - 1].id
|
||||||
|
const currentPath = path.map(p => p.name).join(' > ')
|
||||||
|
|
||||||
|
const loadFolder = useCallback(async (folderId: string) => {
|
||||||
|
setLoading(true)
|
||||||
|
setError(null)
|
||||||
|
try {
|
||||||
|
const response = await browseGoogleDrive(accountId, folderId)
|
||||||
|
// Only show folders, not files
|
||||||
|
setItems(response.items.filter(item => item.is_folder))
|
||||||
|
} catch (e) {
|
||||||
|
setError(e instanceof Error ? e.message : 'Failed to load folder')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [accountId, browseGoogleDrive])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadFolder(currentFolderId)
|
||||||
|
}, [currentFolderId, loadFolder])
|
||||||
|
|
||||||
|
const navigateToFolder = (item: DriveItem) => {
|
||||||
|
setPath([...path, { id: item.id, name: item.name }])
|
||||||
|
}
|
||||||
|
|
||||||
|
const navigateToPathIndex = (index: number) => {
|
||||||
|
setPath(path.slice(0, index + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleExclude = (item: DriveItem) => {
|
||||||
|
const newExcluded = new Map(excluded)
|
||||||
|
if (newExcluded.has(item.id)) {
|
||||||
|
newExcluded.delete(item.id)
|
||||||
|
} else {
|
||||||
|
newExcluded.set(item.id, {
|
||||||
|
id: item.id,
|
||||||
|
name: item.name,
|
||||||
|
path: currentPath + ' > ' + item.name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
setExcluded(newExcluded)
|
||||||
|
}
|
||||||
|
|
||||||
|
const removeExclusion = (id: string) => {
|
||||||
|
const newExcluded = new Map(excluded)
|
||||||
|
newExcluded.delete(id)
|
||||||
|
setExcluded(newExcluded)
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleSave = () => {
|
||||||
|
onSave(Array.from(excluded.keys()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal title={`Manage Exclusions: ${folder.folder_name}`} onClose={onCancel}>
|
||||||
|
<div className="exclusion-browser">
|
||||||
|
{/* Current exclusions */}
|
||||||
|
{excluded.size > 0 && (
|
||||||
|
<div className="excluded-list">
|
||||||
|
<h5>Excluded Folders ({excluded.size})</h5>
|
||||||
|
<div className="excluded-items">
|
||||||
|
{Array.from(excluded.values()).map(item => (
|
||||||
|
<div key={item.id} className="excluded-item">
|
||||||
|
<span className="excluded-name" title={item.path}>
|
||||||
|
📁 {item.name}
|
||||||
|
</span>
|
||||||
|
<button
|
||||||
|
className="remove-exclusion-btn"
|
||||||
|
onClick={() => removeExclusion(item.id)}
|
||||||
|
title="Remove exclusion"
|
||||||
|
>
|
||||||
|
✕
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Breadcrumb */}
|
||||||
|
<div className="folder-breadcrumb">
|
||||||
|
{path.map((item, index) => (
|
||||||
|
<span key={item.id}>
|
||||||
|
{index > 0 && <span className="breadcrumb-sep">></span>}
|
||||||
|
<button
|
||||||
|
className={`breadcrumb-item ${index === path.length - 1 ? 'current' : ''}`}
|
||||||
|
onClick={() => navigateToPathIndex(index)}
|
||||||
|
disabled={index === path.length - 1}
|
||||||
|
>
|
||||||
|
{item.name}
|
||||||
|
</button>
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Content */}
|
||||||
|
{error && <div className="form-error">{error}</div>}
|
||||||
|
|
||||||
|
{loading ? (
|
||||||
|
<div className="folder-loading">Loading...</div>
|
||||||
|
) : items.length === 0 ? (
|
||||||
|
<div className="folder-empty">No subfolders in this folder</div>
|
||||||
|
) : (
|
||||||
|
<div className="folder-list">
|
||||||
|
{items.map(item => (
|
||||||
|
<div key={item.id} className={`folder-item ${excluded.has(item.id) ? 'excluded' : ''}`}>
|
||||||
|
<label className="folder-item-checkbox">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={excluded.has(item.id)}
|
||||||
|
onChange={() => toggleExclude(item)}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
<span className="folder-item-icon">📁</span>
|
||||||
|
<a
|
||||||
|
className="folder-item-name"
|
||||||
|
href={`https://drive.google.com/drive/folders/${item.id}`}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
>
|
||||||
|
{item.name}
|
||||||
|
</a>
|
||||||
|
{excluded.has(item.id) && (
|
||||||
|
<span className="exclusion-badge">Excluded</span>
|
||||||
|
)}
|
||||||
|
<button
|
||||||
|
className="folder-item-enter"
|
||||||
|
onClick={() => navigateToFolder(item)}
|
||||||
|
title="Browse subfolder"
|
||||||
|
>
|
||||||
|
→
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Footer */}
|
||||||
|
<div className="folder-browser-footer">
|
||||||
|
<span className="selected-count">
|
||||||
|
{excluded.size} folder{excluded.size !== 1 ? 's' : ''} excluded
|
||||||
|
</span>
|
||||||
|
<div className="folder-browser-actions">
|
||||||
|
<button type="button" className="cancel-btn" onClick={onCancel}>Cancel</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className="submit-btn"
|
||||||
|
onClick={handleSave}
|
||||||
|
>
|
||||||
|
Save Exclusions
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
interface GoogleFolderFormProps {
|
interface GoogleFolderFormProps {
|
||||||
accountId: number
|
accountId: number
|
||||||
onSubmit: (data: any) => Promise<void>
|
onSubmit: (data: any) => Promise<void>
|
||||||
|
|||||||
@ -177,6 +177,7 @@ export interface GoogleFolder {
|
|||||||
check_interval: number
|
check_interval: number
|
||||||
last_sync_at: string | null
|
last_sync_at: string | null
|
||||||
active: boolean
|
active: boolean
|
||||||
|
exclude_folder_ids: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GoogleAccount {
|
export interface GoogleAccount {
|
||||||
@ -205,6 +206,7 @@ export interface GoogleFolderUpdate {
|
|||||||
tags?: string[]
|
tags?: string[]
|
||||||
check_interval?: number
|
check_interval?: number
|
||||||
active?: boolean
|
active?: boolean
|
||||||
|
exclude_folder_ids?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Types for Google Drive browsing
|
// Types for Google Drive browsing
|
||||||
|
|||||||
@ -29,6 +29,9 @@ from memory.common.db.models import (
|
|||||||
ScheduledLLMCall,
|
ScheduledLLMCall,
|
||||||
SourceItem,
|
SourceItem,
|
||||||
User,
|
User,
|
||||||
|
GoogleDoc,
|
||||||
|
GoogleFolder,
|
||||||
|
GoogleAccount,
|
||||||
)
|
)
|
||||||
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
|
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
|
||||||
|
|
||||||
@ -394,6 +397,42 @@ class GithubItemAdmin(ModelView, model=GithubItem):
|
|||||||
column_sortable_list = ["github_updated_at", "created_at"]
|
column_sortable_list = ["github_updated_at", "created_at"]
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDocAdmin(ModelView, model=GoogleDoc):
|
||||||
|
column_list = source_columns(
|
||||||
|
GoogleDoc,
|
||||||
|
"title",
|
||||||
|
"folder_path",
|
||||||
|
"owner",
|
||||||
|
"last_modified_by",
|
||||||
|
"word_count",
|
||||||
|
"content_hash",
|
||||||
|
)
|
||||||
|
column_searchable_list = ["title", "folder_path", "owner", "last_modified_by", "id"]
|
||||||
|
column_sortable_list = ["google_modified_at", "created_at"]
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleFolderAdmin(ModelView, model=GoogleFolder):
|
||||||
|
column_list = source_columns(
|
||||||
|
GoogleFolder, "folder_name", "folder_path", "account", "active"
|
||||||
|
)
|
||||||
|
column_searchable_list = ["folder_name", "folder_path", "id"]
|
||||||
|
column_sortable_list = ["last_sync_at", "created_at"]
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleAccountAdmin(ModelView, model=GoogleAccount):
|
||||||
|
column_list = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"email",
|
||||||
|
"active",
|
||||||
|
"last_sync_at",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
column_searchable_list = ["name", "email", "id"]
|
||||||
|
column_sortable_list = ["last_sync_at", "created_at"]
|
||||||
|
|
||||||
|
|
||||||
def setup_admin(admin: Admin):
|
def setup_admin(admin: Admin):
|
||||||
"""Add all admin views to the admin instance with OAuth protection."""
|
"""Add all admin views to the admin instance with OAuth protection."""
|
||||||
admin.add_view(SourceItemAdmin)
|
admin.add_view(SourceItemAdmin)
|
||||||
@ -421,3 +460,6 @@ def setup_admin(admin: Admin):
|
|||||||
admin.add_view(GithubAccountAdmin)
|
admin.add_view(GithubAccountAdmin)
|
||||||
admin.add_view(GithubRepoAdmin)
|
admin.add_view(GithubRepoAdmin)
|
||||||
admin.add_view(GithubItemAdmin)
|
admin.add_view(GithubItemAdmin)
|
||||||
|
admin.add_view(GoogleDocAdmin)
|
||||||
|
admin.add_view(GoogleFolderAdmin)
|
||||||
|
admin.add_view(GoogleAccountAdmin)
|
||||||
|
|||||||
@ -60,6 +60,7 @@ class FolderUpdate(BaseModel):
|
|||||||
tags: list[str] | None = None
|
tags: list[str] | None = None
|
||||||
check_interval: int | None = None
|
check_interval: int | None = None
|
||||||
active: bool | None = None
|
active: bool | None = None
|
||||||
|
exclude_folder_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class FolderResponse(BaseModel):
|
class FolderResponse(BaseModel):
|
||||||
@ -73,6 +74,7 @@ class FolderResponse(BaseModel):
|
|||||||
check_interval: int
|
check_interval: int
|
||||||
last_sync_at: str | None
|
last_sync_at: str | None
|
||||||
active: bool
|
active: bool
|
||||||
|
exclude_folder_ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
class AccountResponse(BaseModel):
|
class AccountResponse(BaseModel):
|
||||||
@ -370,6 +372,7 @@ def list_accounts(
|
|||||||
folder.last_sync_at.isoformat() if folder.last_sync_at else None
|
folder.last_sync_at.isoformat() if folder.last_sync_at else None
|
||||||
),
|
),
|
||||||
active=cast(bool, folder.active),
|
active=cast(bool, folder.active),
|
||||||
|
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
|
||||||
)
|
)
|
||||||
for folder in account.folders
|
for folder in account.folders
|
||||||
],
|
],
|
||||||
@ -531,6 +534,7 @@ def add_folder(
|
|||||||
check_interval=cast(int, new_folder.check_interval),
|
check_interval=cast(int, new_folder.check_interval),
|
||||||
last_sync_at=None,
|
last_sync_at=None,
|
||||||
active=cast(bool, new_folder.active),
|
active=cast(bool, new_folder.active),
|
||||||
|
exclude_folder_ids=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -567,6 +571,8 @@ def update_folder(
|
|||||||
folder.check_interval = updates.check_interval
|
folder.check_interval = updates.check_interval
|
||||||
if updates.active is not None:
|
if updates.active is not None:
|
||||||
folder.active = updates.active
|
folder.active = updates.active
|
||||||
|
if updates.exclude_folder_ids is not None:
|
||||||
|
folder.exclude_folder_ids = updates.exclude_folder_ids
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(folder)
|
db.refresh(folder)
|
||||||
@ -584,6 +590,7 @@ def update_folder(
|
|||||||
folder.last_sync_at.isoformat() if folder.last_sync_at else None
|
folder.last_sync_at.isoformat() if folder.last_sync_at else None
|
||||||
),
|
),
|
||||||
active=cast(bool, folder.active),
|
active=cast(bool, folder.active),
|
||||||
|
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -350,6 +350,9 @@ class GoogleFolder(Base):
|
|||||||
# File type filters (empty = all text documents)
|
# File type filters (empty = all text documents)
|
||||||
mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}")
|
mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
# Excluded subfolder IDs (skip these when syncing recursively)
|
||||||
|
exclude_folder_ids = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
# Tags to apply to all documents from this folder
|
# Tags to apply to all documents from this folder
|
||||||
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
tags = Column(ARRAY(Text), nullable=False, server_default="{}")
|
||||||
|
|
||||||
|
|||||||
@ -130,9 +130,19 @@ class GoogleDriveClient:
|
|||||||
recursive: bool = True,
|
recursive: bool = True,
|
||||||
since: datetime | None = None,
|
since: datetime | None = None,
|
||||||
page_size: int = 100,
|
page_size: int = 100,
|
||||||
|
exclude_folder_ids: set[str] | None = None,
|
||||||
) -> Generator[dict, None, None]:
|
) -> Generator[dict, None, None]:
|
||||||
"""List all supported files in a folder with pagination."""
|
"""List all supported files in a folder with pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_id: The Google Drive folder ID to list
|
||||||
|
recursive: Whether to recurse into subfolders
|
||||||
|
since: Only return files modified after this time
|
||||||
|
page_size: Number of files per API page
|
||||||
|
exclude_folder_ids: Set of folder IDs to skip during recursive traversal
|
||||||
|
"""
|
||||||
service = self._get_service()
|
service = self._get_service()
|
||||||
|
exclude_folder_ids = exclude_folder_ids or set()
|
||||||
|
|
||||||
# Build query for supported file types
|
# Build query for supported file types
|
||||||
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
|
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
|
||||||
@ -167,11 +177,16 @@ class GoogleDriveClient:
|
|||||||
|
|
||||||
for file in response.get("files", []):
|
for file in response.get("files", []):
|
||||||
if file["mimeType"] == "application/vnd.google-apps.folder":
|
if file["mimeType"] == "application/vnd.google-apps.folder":
|
||||||
if recursive:
|
if recursive and file["id"] not in exclude_folder_ids:
|
||||||
# Recursively list files in subfolder
|
# Recursively list files in subfolder
|
||||||
yield from self.list_files_in_folder(
|
yield from self.list_files_in_folder(
|
||||||
file["id"], recursive=True, since=since
|
file["id"],
|
||||||
|
recursive=True,
|
||||||
|
since=since,
|
||||||
|
exclude_folder_ids=exclude_folder_ids,
|
||||||
)
|
)
|
||||||
|
elif file["id"] in exclude_folder_ids:
|
||||||
|
logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
|
||||||
else:
|
else:
|
||||||
yield file
|
yield file
|
||||||
|
|
||||||
|
|||||||
@ -251,10 +251,16 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An
|
|||||||
# It's a folder - list and sync all files inside
|
# It's a folder - list and sync all files inside
|
||||||
folder_path = client.get_folder_path(google_id)
|
folder_path = client.get_folder_path(google_id)
|
||||||
|
|
||||||
|
# Get excluded folder IDs
|
||||||
|
exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
|
||||||
|
if exclude_ids:
|
||||||
|
logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
|
||||||
|
|
||||||
for file_meta in client.list_files_in_folder(
|
for file_meta in client.list_files_in_folder(
|
||||||
google_id,
|
google_id,
|
||||||
recursive=cast(bool, folder.recursive),
|
recursive=cast(bool, folder.recursive),
|
||||||
since=since,
|
since=since,
|
||||||
|
exclude_folder_ids=exclude_ids,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
file_data = client.fetch_file(file_meta, folder_path)
|
file_data = client.fetch_file(file_meta, folder_path)
|
||||||
|
|||||||
@ -5,10 +5,27 @@ import pytest
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
||||||
|
|
||||||
# Mock the mcp module and all its submodules before importing anything that uses it
|
# Mock FastMCP - this creates a decorator factory that passes through the function unchanged
|
||||||
|
class MockFastMCP:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def tool(self):
|
||||||
|
def decorator(func):
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the fastmcp module before importing anything that uses it
|
||||||
|
_mock_fastmcp = MagicMock()
|
||||||
|
_mock_fastmcp.FastMCP = MockFastMCP
|
||||||
|
sys.modules["fastmcp"] = _mock_fastmcp
|
||||||
|
|
||||||
|
# Mock the mcp module and all its submodules
|
||||||
_mock_mcp = MagicMock()
|
_mock_mcp = MagicMock()
|
||||||
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
|
_mock_mcp.tool = lambda: lambda f: f
|
||||||
sys.modules["mcp"] = _mock_mcp
|
sys.modules["mcp"] = _mock_mcp
|
||||||
|
sys.modules["mcp.types"] = MagicMock()
|
||||||
sys.modules["mcp.server"] = MagicMock()
|
sys.modules["mcp.server"] = MagicMock()
|
||||||
sys.modules["mcp.server.auth"] = MagicMock()
|
sys.modules["mcp.server.auth"] = MagicMock()
|
||||||
sys.modules["mcp.server.auth.handlers"] = MagicMock()
|
sys.modules["mcp.server.auth.handlers"] = MagicMock()
|
||||||
@ -21,7 +38,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
|
|||||||
# Also mock the memory.api.MCP.base module to avoid MCP imports
|
# Also mock the memory.api.MCP.base module to avoid MCP imports
|
||||||
_mock_base = MagicMock()
|
_mock_base = MagicMock()
|
||||||
_mock_base.mcp = MagicMock()
|
_mock_base.mcp = MagicMock()
|
||||||
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
|
_mock_base.mcp.tool = lambda: lambda f: f
|
||||||
sys.modules["memory.api.MCP.base"] = _mock_base
|
sys.modules["memory.api.MCP.base"] = _mock_base
|
||||||
|
|
||||||
from memory.common.db.models import GithubItem
|
from memory.common.db.models import GithubItem
|
||||||
@ -189,7 +206,7 @@ def sample_issues(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_no_filters(db_session, sample_issues):
|
async def test_list_github_issues_no_filters(db_session, sample_issues):
|
||||||
"""Test listing all issues without filters."""
|
"""Test listing all issues without filters."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues()
|
results = await list_github_issues()
|
||||||
@ -203,7 +220,7 @@ async def test_list_github_issues_no_filters(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
|
||||||
"""Test filtering by repository."""
|
"""Test filtering by repository."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(repo="owner/repo1")
|
results = await list_github_issues(repo="owner/repo1")
|
||||||
@ -215,7 +232,7 @@ async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
|
||||||
"""Test filtering by assignee."""
|
"""Test filtering by assignee."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(assignee="alice")
|
results = await list_github_issues(assignee="alice")
|
||||||
@ -227,7 +244,7 @@ async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_author(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_author(db_session, sample_issues):
|
||||||
"""Test filtering by author."""
|
"""Test filtering by author."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(author="alice")
|
results = await list_github_issues(author="alice")
|
||||||
@ -239,7 +256,7 @@ async def test_list_github_issues_filter_by_author(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_state(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_state(db_session, sample_issues):
|
||||||
"""Test filtering by state."""
|
"""Test filtering by state."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
open_results = await list_github_issues(state="open")
|
open_results = await list_github_issues(state="open")
|
||||||
@ -259,7 +276,7 @@ async def test_list_github_issues_filter_by_state(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
|
||||||
"""Test filtering by kind (issue vs PR)."""
|
"""Test filtering by kind (issue vs PR)."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
issues = await list_github_issues(kind="issue")
|
issues = await list_github_issues(kind="issue")
|
||||||
@ -275,7 +292,7 @@ async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
|
||||||
"""Test filtering by labels."""
|
"""Test filtering by labels."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(labels=["bug"])
|
results = await list_github_issues(labels=["bug"])
|
||||||
@ -287,7 +304,7 @@ async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_project_status(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_project_status(db_session, sample_issues):
|
||||||
"""Test filtering by project status."""
|
"""Test filtering by project status."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(project_status="In Progress")
|
results = await list_github_issues(project_status="In Progress")
|
||||||
@ -300,7 +317,7 @@ async def test_list_github_issues_filter_by_project_status(db_session, sample_is
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_project_field(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_project_field(db_session, sample_issues):
|
||||||
"""Test filtering by project field (JSONB)."""
|
"""Test filtering by project field (JSONB)."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(
|
results = await list_github_issues(
|
||||||
@ -316,7 +333,7 @@ async def test_list_github_issues_filter_by_project_field(db_session, sample_iss
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues):
|
||||||
"""Test filtering by updated_since."""
|
"""Test filtering by updated_since."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=2)).isoformat()
|
since = (now - timedelta(days=2)).isoformat()
|
||||||
@ -332,7 +349,7 @@ async def test_list_github_issues_filter_by_updated_since(db_session, sample_iss
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues):
|
async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues):
|
||||||
"""Test filtering by updated_before (stale issues)."""
|
"""Test filtering by updated_before (stale issues)."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
before = (now - timedelta(days=30)).isoformat()
|
before = (now - timedelta(days=30)).isoformat()
|
||||||
@ -348,7 +365,7 @@ async def test_list_github_issues_filter_by_updated_before(db_session, sample_is
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_order_by_created(db_session, sample_issues):
|
async def test_list_github_issues_order_by_created(db_session, sample_issues):
|
||||||
"""Test ordering by created date."""
|
"""Test ordering by created date."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(order_by="created")
|
results = await list_github_issues(order_by="created")
|
||||||
@ -360,7 +377,7 @@ async def test_list_github_issues_order_by_created(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_limit(db_session, sample_issues):
|
async def test_list_github_issues_limit(db_session, sample_issues):
|
||||||
"""Test limiting results."""
|
"""Test limiting results."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(limit=2)
|
results = await list_github_issues(limit=2)
|
||||||
@ -371,7 +388,7 @@ async def test_list_github_issues_limit(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
|
async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
|
||||||
"""Test that limit is capped at 200."""
|
"""Test that limit is capped at 200."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
# Request 500 but should be capped at 200
|
# Request 500 but should be capped at 200
|
||||||
@ -384,7 +401,7 @@ async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_combined_filters(db_session, sample_issues):
|
async def test_list_github_issues_combined_filters(db_session, sample_issues):
|
||||||
"""Test combining multiple filters."""
|
"""Test combining multiple filters."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(
|
results = await list_github_issues(
|
||||||
@ -402,7 +419,7 @@ async def test_list_github_issues_combined_filters(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_github_issues_url_construction(db_session, sample_issues):
|
async def test_list_github_issues_url_construction(db_session, sample_issues):
|
||||||
"""Test that URLs are correctly constructed."""
|
"""Test that URLs are correctly constructed."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(kind="issue", limit=1)
|
results = await list_github_issues(kind="issue", limit=1)
|
||||||
@ -425,7 +442,7 @@ async def test_list_github_issues_url_construction(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_issue_details_found(db_session, sample_issues):
|
async def test_github_issue_details_found(db_session, sample_issues):
|
||||||
"""Test getting details for an existing issue."""
|
"""Test getting details for an existing issue."""
|
||||||
from memory.api.MCP.github import github_issue_details
|
from memory.api.MCP.servers.github import github_issue_details
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_issue_details(repo="owner/repo1", number=1)
|
result = await github_issue_details(repo="owner/repo1", number=1)
|
||||||
@ -440,7 +457,7 @@ async def test_github_issue_details_found(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_issue_details_not_found(db_session, sample_issues):
|
async def test_github_issue_details_not_found(db_session, sample_issues):
|
||||||
"""Test getting details for a non-existent issue."""
|
"""Test getting details for a non-existent issue."""
|
||||||
from memory.api.MCP.github import github_issue_details
|
from memory.api.MCP.servers.github import github_issue_details
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
@ -450,7 +467,7 @@ async def test_github_issue_details_not_found(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_issue_details_pr(db_session, sample_issues):
|
async def test_github_issue_details_pr(db_session, sample_issues):
|
||||||
"""Test getting details for a PR."""
|
"""Test getting details for a PR."""
|
||||||
from memory.api.MCP.github import github_issue_details
|
from memory.api.MCP.servers.github import github_issue_details
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_issue_details(repo="owner/repo1", number=50)
|
result = await github_issue_details(repo="owner/repo1", number=50)
|
||||||
@ -468,7 +485,7 @@ async def test_github_issue_details_pr(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_by_client(db_session, sample_issues):
|
async def test_github_work_summary_by_client(db_session, sample_issues):
|
||||||
"""Test work summary grouped by client."""
|
"""Test work summary grouped by client."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -489,7 +506,7 @@ async def test_github_work_summary_by_client(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_by_status(db_session, sample_issues):
|
async def test_github_work_summary_by_status(db_session, sample_issues):
|
||||||
"""Test work summary grouped by status."""
|
"""Test work summary grouped by status."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -506,7 +523,7 @@ async def test_github_work_summary_by_status(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_by_author(db_session, sample_issues):
|
async def test_github_work_summary_by_author(db_session, sample_issues):
|
||||||
"""Test work summary grouped by author."""
|
"""Test work summary grouped by author."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -522,7 +539,7 @@ async def test_github_work_summary_by_author(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_by_repo(db_session, sample_issues):
|
async def test_github_work_summary_by_repo(db_session, sample_issues):
|
||||||
"""Test work summary grouped by repository."""
|
"""Test work summary grouped by repository."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -538,7 +555,7 @@ async def test_github_work_summary_by_repo(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_with_until(db_session, sample_issues):
|
async def test_github_work_summary_with_until(db_session, sample_issues):
|
||||||
"""Test work summary with until date."""
|
"""Test work summary with until date."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -553,7 +570,7 @@ async def test_github_work_summary_with_until(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
|
async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
|
||||||
"""Test work summary filtered by repository."""
|
"""Test work summary filtered by repository."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -571,7 +588,7 @@ async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
|
async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
|
||||||
"""Test work summary with invalid group_by value."""
|
"""Test work summary with invalid group_by value."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -584,7 +601,7 @@ async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_work_summary_includes_sample_issues(db_session, sample_issues):
|
async def test_github_work_summary_includes_sample_issues(db_session, sample_issues):
|
||||||
"""Test that work summary includes sample issues."""
|
"""Test that work summary includes sample issues."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -610,7 +627,7 @@ async def test_github_work_summary_includes_sample_issues(db_session, sample_iss
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_basic(db_session, sample_issues):
|
async def test_github_repo_overview_basic(db_session, sample_issues):
|
||||||
"""Test basic repo overview."""
|
"""Test basic repo overview."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -625,7 +642,7 @@ async def test_github_repo_overview_basic(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_counts(db_session, sample_issues):
|
async def test_github_repo_overview_counts(db_session, sample_issues):
|
||||||
"""Test repo overview counts are correct."""
|
"""Test repo overview counts are correct."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -638,7 +655,7 @@ async def test_github_repo_overview_counts(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
|
async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
|
||||||
"""Test repo overview includes status breakdown."""
|
"""Test repo overview includes status breakdown."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -651,7 +668,7 @@ async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_top_assignees(db_session, sample_issues):
|
async def test_github_repo_overview_top_assignees(db_session, sample_issues):
|
||||||
"""Test repo overview includes top assignees."""
|
"""Test repo overview includes top assignees."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -663,7 +680,7 @@ async def test_github_repo_overview_top_assignees(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_labels(db_session, sample_issues):
|
async def test_github_repo_overview_labels(db_session, sample_issues):
|
||||||
"""Test repo overview includes labels."""
|
"""Test repo overview includes labels."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -676,7 +693,7 @@ async def test_github_repo_overview_labels(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_last_updated(db_session, sample_issues):
|
async def test_github_repo_overview_last_updated(db_session, sample_issues):
|
||||||
"""Test repo overview includes last updated timestamp."""
|
"""Test repo overview includes last updated timestamp."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="owner/repo1")
|
result = await github_repo_overview(repo="owner/repo1")
|
||||||
@ -688,7 +705,7 @@ async def test_github_repo_overview_last_updated(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_repo_overview_empty_repo(db_session):
|
async def test_github_repo_overview_empty_repo(db_session):
|
||||||
"""Test repo overview for a repo with no issues."""
|
"""Test repo overview for a repo with no issues."""
|
||||||
from memory.api.MCP.github import github_repo_overview
|
from memory.api.MCP.servers.github import github_repo_overview
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_repo_overview(repo="nonexistent/repo")
|
result = await github_repo_overview(repo="nonexistent/repo")
|
||||||
@ -704,7 +721,7 @@ async def test_github_repo_overview_empty_repo(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_github_issues_basic(db_session, sample_issues):
|
async def test_search_github_issues_basic(db_session, sample_issues):
|
||||||
"""Test basic search functionality."""
|
"""Test basic search functionality."""
|
||||||
from memory.api.MCP.github import search_github_issues
|
from memory.api.MCP.servers.github import search_github_issues
|
||||||
|
|
||||||
mock_search_result = Mock()
|
mock_search_result = Mock()
|
||||||
mock_search_result.id = sample_issues[0].id
|
mock_search_result.id = sample_issues[0].id
|
||||||
@ -723,7 +740,7 @@ async def test_search_github_issues_basic(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
|
async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
|
||||||
"""Test search with repository filter."""
|
"""Test search with repository filter."""
|
||||||
from memory.api.MCP.github import search_github_issues
|
from memory.api.MCP.servers.github import search_github_issues
|
||||||
|
|
||||||
mock_search_result = Mock()
|
mock_search_result = Mock()
|
||||||
mock_search_result.id = sample_issues[0].id
|
mock_search_result.id = sample_issues[0].id
|
||||||
@ -745,7 +762,7 @@ async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_github_issues_with_state_filter(db_session, sample_issues):
|
async def test_search_github_issues_with_state_filter(db_session, sample_issues):
|
||||||
"""Test search with state filter."""
|
"""Test search with state filter."""
|
||||||
from memory.api.MCP.github import search_github_issues
|
from memory.api.MCP.servers.github import search_github_issues
|
||||||
|
|
||||||
mock_search_result = Mock()
|
mock_search_result = Mock()
|
||||||
mock_search_result.id = sample_issues[0].id
|
mock_search_result.id = sample_issues[0].id
|
||||||
@ -762,7 +779,7 @@ async def test_search_github_issues_with_state_filter(db_session, sample_issues)
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_github_issues_limit(db_session, sample_issues):
|
async def test_search_github_issues_limit(db_session, sample_issues):
|
||||||
"""Test search respects limit."""
|
"""Test search respects limit."""
|
||||||
from memory.api.MCP.github import search_github_issues
|
from memory.api.MCP.servers.github import search_github_issues
|
||||||
|
|
||||||
mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])]
|
mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])]
|
||||||
|
|
||||||
@ -779,7 +796,7 @@ async def test_search_github_issues_limit(db_session, sample_issues):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_github_issues_uses_github_modality(db_session, sample_issues):
|
async def test_search_github_issues_uses_github_modality(db_session, sample_issues):
|
||||||
"""Test that search uses github modality."""
|
"""Test that search uses github modality."""
|
||||||
from memory.api.MCP.github import search_github_issues
|
from memory.api.MCP.servers.github import search_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search:
|
with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search:
|
||||||
@ -797,7 +814,7 @@ async def test_search_github_issues_uses_github_modality(db_session, sample_issu
|
|||||||
|
|
||||||
def test_build_github_url_issue():
|
def test_build_github_url_issue():
|
||||||
"""Test URL construction for issues."""
|
"""Test URL construction for issues."""
|
||||||
from memory.api.MCP.github import _build_github_url
|
from memory.api.MCP.servers.github import _build_github_url
|
||||||
|
|
||||||
url = _build_github_url("owner/repo", 123, "issue")
|
url = _build_github_url("owner/repo", 123, "issue")
|
||||||
assert url == "https://github.com/owner/repo/issues/123"
|
assert url == "https://github.com/owner/repo/issues/123"
|
||||||
@ -805,7 +822,7 @@ def test_build_github_url_issue():
|
|||||||
|
|
||||||
def test_build_github_url_pr():
|
def test_build_github_url_pr():
|
||||||
"""Test URL construction for PRs."""
|
"""Test URL construction for PRs."""
|
||||||
from memory.api.MCP.github import _build_github_url
|
from memory.api.MCP.servers.github import _build_github_url
|
||||||
|
|
||||||
url = _build_github_url("owner/repo", 456, "pr")
|
url = _build_github_url("owner/repo", 456, "pr")
|
||||||
assert url == "https://github.com/owner/repo/pull/456"
|
assert url == "https://github.com/owner/repo/pull/456"
|
||||||
@ -813,7 +830,7 @@ def test_build_github_url_pr():
|
|||||||
|
|
||||||
def test_build_github_url_no_number():
|
def test_build_github_url_no_number():
|
||||||
"""Test URL construction without number."""
|
"""Test URL construction without number."""
|
||||||
from memory.api.MCP.github import _build_github_url
|
from memory.api.MCP.servers.github import _build_github_url
|
||||||
|
|
||||||
url = _build_github_url("owner/repo", None, "issue")
|
url = _build_github_url("owner/repo", None, "issue")
|
||||||
assert url == "https://github.com/owner/repo"
|
assert url == "https://github.com/owner/repo"
|
||||||
@ -821,7 +838,7 @@ def test_build_github_url_no_number():
|
|||||||
|
|
||||||
def test_serialize_issue_basic(db_session, sample_issues):
|
def test_serialize_issue_basic(db_session, sample_issues):
|
||||||
"""Test issue serialization."""
|
"""Test issue serialization."""
|
||||||
from memory.api.MCP.github import _serialize_issue
|
from memory.api.MCP.servers.github import _serialize_issue
|
||||||
|
|
||||||
issue = sample_issues[0]
|
issue = sample_issues[0]
|
||||||
result = _serialize_issue(issue)
|
result = _serialize_issue(issue)
|
||||||
@ -838,7 +855,7 @@ def test_serialize_issue_basic(db_session, sample_issues):
|
|||||||
|
|
||||||
def test_serialize_issue_with_content(db_session, sample_issues):
|
def test_serialize_issue_with_content(db_session, sample_issues):
|
||||||
"""Test issue serialization with content."""
|
"""Test issue serialization with content."""
|
||||||
from memory.api.MCP.github import _serialize_issue
|
from memory.api.MCP.servers.github import _serialize_issue
|
||||||
|
|
||||||
issue = sample_issues[0]
|
issue = sample_issues[0]
|
||||||
result = _serialize_issue(issue, include_content=True)
|
result = _serialize_issue(issue, include_content=True)
|
||||||
@ -865,7 +882,7 @@ async def test_list_github_issues_ordering(
|
|||||||
db_session, sample_issues, order_by, expected_first_number
|
db_session, sample_issues, order_by, expected_first_number
|
||||||
):
|
):
|
||||||
"""Test different ordering options."""
|
"""Test different ordering options."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(order_by=order_by)
|
results = await list_github_issues(order_by=order_by)
|
||||||
@ -880,7 +897,7 @@ async def test_list_github_issues_ordering(
|
|||||||
)
|
)
|
||||||
async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by):
|
async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by):
|
||||||
"""Test all valid group_by options."""
|
"""Test all valid group_by options."""
|
||||||
from memory.api.MCP.github import github_work_summary
|
from memory.api.MCP.servers.github import github_work_summary
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
since = (now - timedelta(days=30)).isoformat()
|
since = (now - timedelta(days=30)).isoformat()
|
||||||
@ -906,7 +923,7 @@ async def test_list_github_issues_label_filtering(
|
|||||||
db_session, sample_issues, labels, expected_count
|
db_session, sample_issues, labels, expected_count
|
||||||
):
|
):
|
||||||
"""Test various label filtering scenarios."""
|
"""Test various label filtering scenarios."""
|
||||||
from memory.api.MCP.github import list_github_issues
|
from memory.api.MCP.servers.github import list_github_issues
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
results = await list_github_issues(labels=labels)
|
results = await list_github_issues(labels=labels)
|
||||||
@ -1014,7 +1031,7 @@ def sample_pr_with_data(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data):
|
async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data):
|
||||||
"""Test that github_issue_details includes PR data for PRs."""
|
"""Test that github_issue_details includes PR data for PRs."""
|
||||||
from memory.api.MCP.github import github_issue_details
|
from memory.api.MCP.servers.github import github_issue_details
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_issue_details(repo="owner/repo1", number=999)
|
result = await github_issue_details(repo="owner/repo1", number=999)
|
||||||
@ -1034,7 +1051,7 @@ async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues):
|
async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues):
|
||||||
"""Test that github_issue_details does not include pr_data for issues."""
|
"""Test that github_issue_details does not include pr_data for issues."""
|
||||||
from memory.api.MCP.github import github_issue_details
|
from memory.api.MCP.servers.github import github_issue_details
|
||||||
|
|
||||||
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
with patch("memory.api.MCP.github.make_session", return_value=db_session):
|
||||||
result = await github_issue_details(repo="owner/repo1", number=1)
|
result = await github_issue_details(repo="owner/repo1", number=1)
|
||||||
@ -1045,7 +1062,7 @@ async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_iss
|
|||||||
|
|
||||||
def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
|
def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
|
||||||
"""Test that _serialize_issue includes pr_data when include_content=True."""
|
"""Test that _serialize_issue includes pr_data when include_content=True."""
|
||||||
from memory.api.MCP.github import _serialize_issue
|
from memory.api.MCP.servers.github import _serialize_issue
|
||||||
|
|
||||||
result = _serialize_issue(sample_pr_with_data, include_content=True)
|
result = _serialize_issue(sample_pr_with_data, include_content=True)
|
||||||
|
|
||||||
@ -1056,7 +1073,7 @@ def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
|
|||||||
|
|
||||||
def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data):
|
def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data):
|
||||||
"""Test that _serialize_issue excludes pr_data when include_content=False."""
|
"""Test that _serialize_issue excludes pr_data when include_content=False."""
|
||||||
from memory.api.MCP.github import _serialize_issue
|
from memory.api.MCP.servers.github import _serialize_issue
|
||||||
|
|
||||||
result = _serialize_issue(sample_pr_with_data, include_content=False)
|
result = _serialize_issue(sample_pr_with_data, include_content=False)
|
||||||
|
|
||||||
|
|||||||
@ -4,10 +4,27 @@ import sys
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
from unittest.mock import AsyncMock, Mock, MagicMock, patch
|
||||||
|
|
||||||
# Mock the mcp module and all its submodules before importing anything that uses it
|
# Mock FastMCP - this creates a decorator factory that passes through the function unchanged
|
||||||
|
class MockFastMCP:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def tool(self):
|
||||||
|
def decorator(func):
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# Mock the fastmcp module before importing anything that uses it
|
||||||
|
_mock_fastmcp = MagicMock()
|
||||||
|
_mock_fastmcp.FastMCP = MockFastMCP
|
||||||
|
sys.modules["fastmcp"] = _mock_fastmcp
|
||||||
|
|
||||||
|
# Mock the mcp module and all its submodules
|
||||||
_mock_mcp = MagicMock()
|
_mock_mcp = MagicMock()
|
||||||
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
|
_mock_mcp.tool = lambda: lambda f: f
|
||||||
sys.modules["mcp"] = _mock_mcp
|
sys.modules["mcp"] = _mock_mcp
|
||||||
|
sys.modules["mcp.types"] = MagicMock()
|
||||||
sys.modules["mcp.server"] = MagicMock()
|
sys.modules["mcp.server"] = MagicMock()
|
||||||
sys.modules["mcp.server.auth"] = MagicMock()
|
sys.modules["mcp.server.auth"] = MagicMock()
|
||||||
sys.modules["mcp.server.auth.handlers"] = MagicMock()
|
sys.modules["mcp.server.auth.handlers"] = MagicMock()
|
||||||
@ -20,7 +37,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
|
|||||||
# Also mock the memory.api.MCP.base module to avoid MCP imports
|
# Also mock the memory.api.MCP.base module to avoid MCP imports
|
||||||
_mock_base = MagicMock()
|
_mock_base = MagicMock()
|
||||||
_mock_base.mcp = MagicMock()
|
_mock_base.mcp = MagicMock()
|
||||||
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
|
_mock_base.mcp.tool = lambda: lambda f: f
|
||||||
sys.modules["memory.api.MCP.base"] = _mock_base
|
sys.modules["memory.api.MCP.base"] = _mock_base
|
||||||
|
|
||||||
from memory.common.db.models import Person
|
from memory.common.db.models import Person
|
||||||
@ -97,7 +114,7 @@ def sample_people(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_person_success(db_session):
|
async def test_add_person_success(db_session):
|
||||||
"""Test adding a new person."""
|
"""Test adding a new person."""
|
||||||
from memory.api.MCP.people import add_person
|
from memory.api.MCP.servers.people import add_person
|
||||||
|
|
||||||
mock_task = Mock()
|
mock_task = Mock()
|
||||||
mock_task.id = "task-123"
|
mock_task.id = "task-123"
|
||||||
@ -128,7 +145,7 @@ async def test_add_person_success(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_person_already_exists(db_session, sample_people):
|
async def test_add_person_already_exists(db_session, sample_people):
|
||||||
"""Test adding a person that already exists."""
|
"""Test adding a person that already exists."""
|
||||||
from memory.api.MCP.people import add_person
|
from memory.api.MCP.servers.people import add_person
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
with pytest.raises(ValueError, match="already exists"):
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
@ -141,7 +158,7 @@ async def test_add_person_already_exists(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_person_minimal(db_session):
|
async def test_add_person_minimal(db_session):
|
||||||
"""Test adding a person with minimal data."""
|
"""Test adding a person with minimal data."""
|
||||||
from memory.api.MCP.people import add_person
|
from memory.api.MCP.servers.people import add_person
|
||||||
|
|
||||||
mock_task = Mock()
|
mock_task = Mock()
|
||||||
mock_task.id = "task-456"
|
mock_task.id = "task-456"
|
||||||
@ -166,7 +183,7 @@ async def test_add_person_minimal(db_session):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_person_info_success(db_session, sample_people):
|
async def test_update_person_info_success(db_session, sample_people):
|
||||||
"""Test updating a person's info."""
|
"""Test updating a person's info."""
|
||||||
from memory.api.MCP.people import update_person_info
|
from memory.api.MCP.servers.people import update_person_info
|
||||||
|
|
||||||
mock_task = Mock()
|
mock_task = Mock()
|
||||||
mock_task.id = "task-789"
|
mock_task.id = "task-789"
|
||||||
@ -188,7 +205,7 @@ async def test_update_person_info_success(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_person_info_not_found(db_session, sample_people):
|
async def test_update_person_info_not_found(db_session, sample_people):
|
||||||
"""Test updating a person that doesn't exist."""
|
"""Test updating a person that doesn't exist."""
|
||||||
from memory.api.MCP.people import update_person_info
|
from memory.api.MCP.servers.people import update_person_info
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
@ -201,7 +218,7 @@ async def test_update_person_info_not_found(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_person_info_with_merge_params(db_session, sample_people):
|
async def test_update_person_info_with_merge_params(db_session, sample_people):
|
||||||
"""Test that update passes all merge parameters."""
|
"""Test that update passes all merge parameters."""
|
||||||
from memory.api.MCP.people import update_person_info
|
from memory.api.MCP.servers.people import update_person_info
|
||||||
|
|
||||||
mock_task = Mock()
|
mock_task = Mock()
|
||||||
mock_task.id = "task-merge"
|
mock_task.id = "task-merge"
|
||||||
@ -234,7 +251,7 @@ async def test_update_person_info_with_merge_params(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_person_found(db_session, sample_people):
|
async def test_get_person_found(db_session, sample_people):
|
||||||
"""Test getting a person that exists."""
|
"""Test getting a person that exists."""
|
||||||
from memory.api.MCP.people import get_person
|
from memory.api.MCP.servers.people import get_person
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
result = await get_person(identifier="alice_chen")
|
result = await get_person(identifier="alice_chen")
|
||||||
@ -251,7 +268,7 @@ async def test_get_person_found(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_person_not_found(db_session, sample_people):
|
async def test_get_person_not_found(db_session, sample_people):
|
||||||
"""Test getting a person that doesn't exist."""
|
"""Test getting a person that doesn't exist."""
|
||||||
from memory.api.MCP.people import get_person
|
from memory.api.MCP.servers.people import get_person
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
result = await get_person(identifier="nonexistent_person")
|
result = await get_person(identifier="nonexistent_person")
|
||||||
@ -267,7 +284,7 @@ async def test_get_person_not_found(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_no_filters(db_session, sample_people):
|
async def test_list_people_no_filters(db_session, sample_people):
|
||||||
"""Test listing all people without filters."""
|
"""Test listing all people without filters."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people()
|
results = await list_people()
|
||||||
@ -282,7 +299,7 @@ async def test_list_people_no_filters(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_filter_by_tags(db_session, sample_people):
|
async def test_list_people_filter_by_tags(db_session, sample_people):
|
||||||
"""Test filtering by tags."""
|
"""Test filtering by tags."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(tags=["work"])
|
results = await list_people(tags=["work"])
|
||||||
@ -294,7 +311,7 @@ async def test_list_people_filter_by_tags(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_filter_by_search(db_session, sample_people):
|
async def test_list_people_filter_by_search(db_session, sample_people):
|
||||||
"""Test filtering by search term."""
|
"""Test filtering by search term."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(search="alice")
|
results = await list_people(search="alice")
|
||||||
@ -306,7 +323,7 @@ async def test_list_people_filter_by_search(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_search_in_notes(db_session, sample_people):
|
async def test_list_people_search_in_notes(db_session, sample_people):
|
||||||
"""Test that search works on notes content."""
|
"""Test that search works on notes content."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(search="climbing")
|
results = await list_people(search="climbing")
|
||||||
@ -318,7 +335,7 @@ async def test_list_people_search_in_notes(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_limit(db_session, sample_people):
|
async def test_list_people_limit(db_session, sample_people):
|
||||||
"""Test limiting results."""
|
"""Test limiting results."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(limit=1)
|
results = await list_people(limit=1)
|
||||||
@ -329,7 +346,7 @@ async def test_list_people_limit(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_limit_max_enforced(db_session, sample_people):
|
async def test_list_people_limit_max_enforced(db_session, sample_people):
|
||||||
"""Test that limit is capped at 200."""
|
"""Test that limit is capped at 200."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
# Request 500 but should be capped at 200
|
# Request 500 but should be capped at 200
|
||||||
@ -342,7 +359,7 @@ async def test_list_people_limit_max_enforced(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_people_combined_filters(db_session, sample_people):
|
async def test_list_people_combined_filters(db_session, sample_people):
|
||||||
"""Test combining tag and search filters."""
|
"""Test combining tag and search filters."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(tags=["work"], search="chen")
|
results = await list_people(tags=["work"], search="chen")
|
||||||
@ -359,7 +376,7 @@ async def test_list_people_combined_filters(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_person_success(db_session, sample_people):
|
async def test_delete_person_success(db_session, sample_people):
|
||||||
"""Test deleting a person."""
|
"""Test deleting a person."""
|
||||||
from memory.api.MCP.people import delete_person
|
from memory.api.MCP.servers.people import delete_person
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
result = await delete_person(identifier="bob_smith")
|
result = await delete_person(identifier="bob_smith")
|
||||||
@ -376,7 +393,7 @@ async def test_delete_person_success(db_session, sample_people):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_person_not_found(db_session, sample_people):
|
async def test_delete_person_not_found(db_session, sample_people):
|
||||||
"""Test deleting a person that doesn't exist."""
|
"""Test deleting a person that doesn't exist."""
|
||||||
from memory.api.MCP.people import delete_person
|
from memory.api.MCP.servers.people import delete_person
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
@ -390,7 +407,7 @@ async def test_delete_person_not_found(db_session, sample_people):
|
|||||||
|
|
||||||
def test_person_to_dict(sample_people):
|
def test_person_to_dict(sample_people):
|
||||||
"""Test the _person_to_dict helper function."""
|
"""Test the _person_to_dict helper function."""
|
||||||
from memory.api.MCP.people import _person_to_dict
|
from memory.api.MCP.servers.people import _person_to_dict
|
||||||
|
|
||||||
person = sample_people[0]
|
person = sample_people[0]
|
||||||
result = _person_to_dict(person)
|
result = _person_to_dict(person)
|
||||||
@ -406,7 +423,7 @@ def test_person_to_dict(sample_people):
|
|||||||
|
|
||||||
def test_person_to_dict_empty_fields(db_session):
|
def test_person_to_dict_empty_fields(db_session):
|
||||||
"""Test _person_to_dict with empty optional fields."""
|
"""Test _person_to_dict with empty optional fields."""
|
||||||
from memory.api.MCP.people import _person_to_dict
|
from memory.api.MCP.servers.people import _person_to_dict
|
||||||
|
|
||||||
person = Person(
|
person = Person(
|
||||||
identifier="empty_person",
|
identifier="empty_person",
|
||||||
@ -448,7 +465,7 @@ def test_person_to_dict_empty_fields(db_session):
|
|||||||
)
|
)
|
||||||
async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
|
async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
|
||||||
"""Test filtering by various tags."""
|
"""Test filtering by various tags."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(tags=[tag])
|
results = await list_people(tags=[tag])
|
||||||
@ -473,7 +490,7 @@ async def test_list_people_various_searches(
|
|||||||
db_session, sample_people, search_term, expected_identifiers
|
db_session, sample_people, search_term, expected_identifiers
|
||||||
):
|
):
|
||||||
"""Test search with various terms."""
|
"""Test search with various terms."""
|
||||||
from memory.api.MCP.people import list_people
|
from memory.api.MCP.servers.people import list_people
|
||||||
|
|
||||||
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
with patch("memory.api.MCP.people.make_session", return_value=db_session):
|
||||||
results = await list_people(search=search_term)
|
results = await list_people(search=search_term)
|
||||||
|
|||||||
@ -69,8 +69,12 @@ class TestSearchResult:
|
|||||||
result = SearchResult.from_source_item(source, chunks)
|
result = SearchResult.from_source_item(source, chunks)
|
||||||
assert result.search_score == 0.9
|
assert result.search_score == 0.9
|
||||||
|
|
||||||
def test_search_score_multiple_chunks_uses_mean(self):
|
def test_search_score_multiple_chunks_uses_max(self):
|
||||||
"""Multiple chunks should use mean of relevance scores, not sum."""
|
"""Multiple chunks should use max of relevance scores.
|
||||||
|
|
||||||
|
Using max finds documents with at least one highly relevant section,
|
||||||
|
which is better for 'half-remembered' searches where users recall one detail.
|
||||||
|
"""
|
||||||
source = self._make_source_item()
|
source = self._make_source_item()
|
||||||
chunks = [
|
chunks = [
|
||||||
self._make_chunk(0.9),
|
self._make_chunk(0.9),
|
||||||
@ -79,8 +83,8 @@ class TestSearchResult:
|
|||||||
]
|
]
|
||||||
|
|
||||||
result = SearchResult.from_source_item(source, chunks)
|
result = SearchResult.from_source_item(source, chunks)
|
||||||
# Mean of 0.9, 0.7, 0.8 = 0.8
|
# Max of 0.9, 0.7, 0.8 = 0.9
|
||||||
assert result.search_score == pytest.approx(0.8)
|
assert result.search_score == pytest.approx(0.9)
|
||||||
|
|
||||||
def test_search_score_empty_chunks(self):
|
def test_search_score_empty_chunks(self):
|
||||||
"""Empty chunk list should result in None or 0 score."""
|
"""Empty chunk list should result in None or 0 score."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user