exclude stuff

This commit is contained in:
Daniel O'Connell 2025-12-29 17:03:08 +01:00
parent 59c45ff1fb
commit 9cf71c9336
12 changed files with 546 additions and 86 deletions

View File

@ -0,0 +1,35 @@
"""Add exclude_folder_ids to google_folders
Revision ID: add_exclude_folder_ids
Revises: 20251229_120000_add_google_drive
Create Date: 2025-12-29 15:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "add_exclude_folder_ids"
down_revision: Union[str, None] = "f2a3b4c5d6e7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"google_folders",
sa.Column(
"exclude_folder_ids",
postgresql.ARRAY(sa.Text()),
nullable=False,
server_default="{}",
),
)
def downgrade() -> None:
op.drop_column("google_folders", "exclude_folder_ids")

View File

@ -1935,4 +1935,98 @@ a.folder-item-name:hover {
.add-btn.secondary:hover { .add-btn.secondary:hover {
background: #edf2f7; background: #edf2f7;
}
/* === Exclusion Browser === */
.exclusion-browser {
display: flex;
flex-direction: column;
min-height: 400px;
max-height: 60vh;
}
.excluded-list {
background: #fffbeb;
border: 1px solid #f59e0b;
border-radius: 6px;
padding: 0.75rem;
margin-bottom: 0.75rem;
}
.excluded-list h5 {
font-size: 0.85rem;
color: #92400e;
margin-bottom: 0.5rem;
}
.excluded-items {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
}
.excluded-item {
display: flex;
align-items: center;
gap: 0.25rem;
background: #fef3c7;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.85rem;
}
.excluded-name {
color: #92400e;
}
.remove-exclusion-btn {
background: none;
border: none;
color: #b45309;
cursor: pointer;
padding: 0 0.25rem;
font-size: 0.75rem;
line-height: 1;
}
.remove-exclusion-btn:hover {
color: #dc2626;
}
.folder-item.excluded {
background: #fffbeb;
border-color: #f59e0b;
}
.exclusion-badge {
background: #f59e0b;
color: white;
padding: 0.125rem 0.375rem;
border-radius: 3px;
font-size: 0.7rem;
margin-left: auto;
margin-right: 0.5rem;
}
.exclusions-btn {
background: #fffbeb;
color: #b45309;
border: 1px solid #f59e0b;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
cursor: pointer;
transition: all 0.15s ease;
}
.exclusions-btn:hover {
background: #fef3c7;
color: #92400e;
}
.setting-badge.warning {
background: #fffbeb;
color: #b45309;
border: 1px solid #f59e0b;
} }

View File

@ -1,6 +1,6 @@
import { useState, useEffect, useCallback } from 'react' import { useState, useEffect, useCallback } from 'react'
import { Link } from 'react-router-dom' import { Link } from 'react-router-dom'
import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources' import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleFolder, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources'
import { import {
SourceCard, SourceCard,
Modal, Modal,
@ -956,6 +956,7 @@ const GoogleDrivePanel = () => {
const [error, setError] = useState<string | null>(null) const [error, setError] = useState<string | null>(null)
const [addingFolderTo, setAddingFolderTo] = useState<number | null>(null) const [addingFolderTo, setAddingFolderTo] = useState<number | null>(null)
const [browsingFoldersFor, setBrowsingFoldersFor] = useState<number | null>(null) const [browsingFoldersFor, setBrowsingFoldersFor] = useState<number | null>(null)
const [managingExclusionsFor, setManagingExclusionsFor] = useState<{ accountId: number; folder: GoogleFolder } | null>(null)
const [uploadingConfig, setUploadingConfig] = useState(false) const [uploadingConfig, setUploadingConfig] = useState(false)
const loadData = useCallback(async () => { const loadData = useCallback(async () => {
@ -1044,6 +1045,12 @@ const GoogleDrivePanel = () => {
loadData() loadData()
} }
const handleUpdateExclusions = async (accountId: number, folderId: number, excludeIds: string[]) => {
await updateGoogleFolder(accountId, folderId, { exclude_folder_ids: excludeIds })
setManagingExclusionsFor(null)
loadData()
}
const handleDeleteFolder = async (accountId: number, folderId: number) => { const handleDeleteFolder = async (accountId: number, folderId: number) => {
await deleteGoogleFolder(accountId, folderId) await deleteGoogleFolder(accountId, folderId)
loadData() loadData()
@ -1188,6 +1195,11 @@ const GoogleDrivePanel = () => {
<div className="folder-settings"> <div className="folder-settings">
{folder.recursive && <span className="setting-badge">Recursive</span>} {folder.recursive && <span className="setting-badge">Recursive</span>}
{folder.include_shared && <span className="setting-badge">Shared</span>} {folder.include_shared && <span className="setting-badge">Shared</span>}
{folder.exclude_folder_ids.length > 0 && (
<span className="setting-badge warning">
{folder.exclude_folder_ids.length} excluded
</span>
)}
</div> </div>
<div className="folder-actions"> <div className="folder-actions">
<SyncButton <SyncButton
@ -1195,6 +1207,14 @@ const GoogleDrivePanel = () => {
disabled={!folder.active || !account.active} disabled={!folder.active || !account.active}
label="Sync" label="Sync"
/> />
{folder.recursive && (
<button
className="exclusions-btn"
onClick={() => setManagingExclusionsFor({ accountId: account.id, folder })}
>
Exclusions
</button>
)}
<button <button
className="toggle-btn" className="toggle-btn"
onClick={() => handleToggleFolderActive(account.id, folder.id, folder.active)} onClick={() => handleToggleFolderActive(account.id, folder.id, folder.active)}
@ -1233,6 +1253,15 @@ const GoogleDrivePanel = () => {
onCancel={() => setBrowsingFoldersFor(null)} onCancel={() => setBrowsingFoldersFor(null)}
/> />
)} )}
{managingExclusionsFor && (
<ExclusionBrowser
accountId={managingExclusionsFor.accountId}
folder={managingExclusionsFor.folder}
onSave={(excludeIds) => handleUpdateExclusions(managingExclusionsFor.accountId, managingExclusionsFor.folder.id, excludeIds)}
onCancel={() => setManagingExclusionsFor(null)}
/>
)}
</div> </div>
) )
} }
@ -1422,6 +1451,195 @@ const FolderBrowser = ({ accountId, onSelect, onCancel }: FolderBrowserProps) =>
) )
} }
// === Exclusion Browser ===
interface ExclusionBrowserProps {
accountId: number
folder: GoogleFolder
onSave: (excludeIds: string[]) => void
onCancel: () => void
}
interface ExcludedFolder {
id: string
name: string
path: string
}
const ExclusionBrowser = ({ accountId, folder, onSave, onCancel }: ExclusionBrowserProps) => {
const { browseGoogleDrive } = useSources()
const [path, setPath] = useState<PathItem[]>([{ id: folder.folder_id, name: folder.folder_name }])
const [items, setItems] = useState<DriveItem[]>([])
const [excluded, setExcluded] = useState<Map<string, ExcludedFolder>>(() => {
// Initialize with current exclusions (we don't have names, so use ID as name)
const map = new Map<string, ExcludedFolder>()
for (const id of folder.exclude_folder_ids) {
map.set(id, { id, name: id, path: '(previously excluded)' })
}
return map
})
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const currentFolderId = path[path.length - 1].id
const currentPath = path.map(p => p.name).join(' > ')
const loadFolder = useCallback(async (folderId: string) => {
setLoading(true)
setError(null)
try {
const response = await browseGoogleDrive(accountId, folderId)
// Only show folders, not files
setItems(response.items.filter(item => item.is_folder))
} catch (e) {
setError(e instanceof Error ? e.message : 'Failed to load folder')
} finally {
setLoading(false)
}
}, [accountId, browseGoogleDrive])
useEffect(() => {
loadFolder(currentFolderId)
}, [currentFolderId, loadFolder])
const navigateToFolder = (item: DriveItem) => {
setPath([...path, { id: item.id, name: item.name }])
}
const navigateToPathIndex = (index: number) => {
setPath(path.slice(0, index + 1))
}
const toggleExclude = (item: DriveItem) => {
const newExcluded = new Map(excluded)
if (newExcluded.has(item.id)) {
newExcluded.delete(item.id)
} else {
newExcluded.set(item.id, {
id: item.id,
name: item.name,
path: currentPath + ' > ' + item.name,
})
}
setExcluded(newExcluded)
}
const removeExclusion = (id: string) => {
const newExcluded = new Map(excluded)
newExcluded.delete(id)
setExcluded(newExcluded)
}
const handleSave = () => {
onSave(Array.from(excluded.keys()))
}
return (
<Modal title={`Manage Exclusions: ${folder.folder_name}`} onClose={onCancel}>
<div className="exclusion-browser">
{/* Current exclusions */}
{excluded.size > 0 && (
<div className="excluded-list">
<h5>Excluded Folders ({excluded.size})</h5>
<div className="excluded-items">
{Array.from(excluded.values()).map(item => (
<div key={item.id} className="excluded-item">
<span className="excluded-name" title={item.path}>
📁 {item.name}
</span>
<button
className="remove-exclusion-btn"
onClick={() => removeExclusion(item.id)}
title="Remove exclusion"
>
</button>
</div>
))}
</div>
</div>
)}
{/* Breadcrumb */}
<div className="folder-breadcrumb">
{path.map((item, index) => (
<span key={item.id}>
{index > 0 && <span className="breadcrumb-sep">&gt;</span>}
<button
className={`breadcrumb-item ${index === path.length - 1 ? 'current' : ''}`}
onClick={() => navigateToPathIndex(index)}
disabled={index === path.length - 1}
>
{item.name}
</button>
</span>
))}
</div>
{/* Content */}
{error && <div className="form-error">{error}</div>}
{loading ? (
<div className="folder-loading">Loading...</div>
) : items.length === 0 ? (
<div className="folder-empty">No subfolders in this folder</div>
) : (
<div className="folder-list">
{items.map(item => (
<div key={item.id} className={`folder-item ${excluded.has(item.id) ? 'excluded' : ''}`}>
<label className="folder-item-checkbox">
<input
type="checkbox"
checked={excluded.has(item.id)}
onChange={() => toggleExclude(item)}
/>
</label>
<span className="folder-item-icon">📁</span>
<a
className="folder-item-name"
href={`https://drive.google.com/drive/folders/${item.id}`}
target="_blank"
rel="noopener noreferrer"
onClick={(e) => e.stopPropagation()}
>
{item.name}
</a>
{excluded.has(item.id) && (
<span className="exclusion-badge">Excluded</span>
)}
<button
className="folder-item-enter"
onClick={() => navigateToFolder(item)}
title="Browse subfolder"
>
</button>
</div>
))}
</div>
)}
{/* Footer */}
<div className="folder-browser-footer">
<span className="selected-count">
{excluded.size} folder{excluded.size !== 1 ? 's' : ''} excluded
</span>
<div className="folder-browser-actions">
<button type="button" className="cancel-btn" onClick={onCancel}>Cancel</button>
<button
type="button"
className="submit-btn"
onClick={handleSave}
>
Save Exclusions
</button>
</div>
</div>
</div>
</Modal>
)
}
interface GoogleFolderFormProps { interface GoogleFolderFormProps {
accountId: number accountId: number
onSubmit: (data: any) => Promise<void> onSubmit: (data: any) => Promise<void>

View File

@ -177,6 +177,7 @@ export interface GoogleFolder {
check_interval: number check_interval: number
last_sync_at: string | null last_sync_at: string | null
active: boolean active: boolean
exclude_folder_ids: string[]
} }
export interface GoogleAccount { export interface GoogleAccount {
@ -205,6 +206,7 @@ export interface GoogleFolderUpdate {
tags?: string[] tags?: string[]
check_interval?: number check_interval?: number
active?: boolean active?: boolean
exclude_folder_ids?: string[]
} }
// Types for Google Drive browsing // Types for Google Drive browsing

View File

@ -29,6 +29,9 @@ from memory.common.db.models import (
ScheduledLLMCall, ScheduledLLMCall,
SourceItem, SourceItem,
User, User,
GoogleDoc,
GoogleFolder,
GoogleAccount,
) )
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
@ -394,6 +397,42 @@ class GithubItemAdmin(ModelView, model=GithubItem):
column_sortable_list = ["github_updated_at", "created_at"] column_sortable_list = ["github_updated_at", "created_at"]
class GoogleDocAdmin(ModelView, model=GoogleDoc):
column_list = source_columns(
GoogleDoc,
"title",
"folder_path",
"owner",
"last_modified_by",
"word_count",
"content_hash",
)
column_searchable_list = ["title", "folder_path", "owner", "last_modified_by", "id"]
column_sortable_list = ["google_modified_at", "created_at"]
class GoogleFolderAdmin(ModelView, model=GoogleFolder):
column_list = source_columns(
GoogleFolder, "folder_name", "folder_path", "account", "active"
)
column_searchable_list = ["folder_name", "folder_path", "id"]
column_sortable_list = ["last_sync_at", "created_at"]
class GoogleAccountAdmin(ModelView, model=GoogleAccount):
column_list = [
"id",
"name",
"email",
"active",
"last_sync_at",
"created_at",
"updated_at",
]
column_searchable_list = ["name", "email", "id"]
column_sortable_list = ["last_sync_at", "created_at"]
def setup_admin(admin: Admin): def setup_admin(admin: Admin):
"""Add all admin views to the admin instance with OAuth protection.""" """Add all admin views to the admin instance with OAuth protection."""
admin.add_view(SourceItemAdmin) admin.add_view(SourceItemAdmin)
@ -421,3 +460,6 @@ def setup_admin(admin: Admin):
admin.add_view(GithubAccountAdmin) admin.add_view(GithubAccountAdmin)
admin.add_view(GithubRepoAdmin) admin.add_view(GithubRepoAdmin)
admin.add_view(GithubItemAdmin) admin.add_view(GithubItemAdmin)
admin.add_view(GoogleDocAdmin)
admin.add_view(GoogleFolderAdmin)
admin.add_view(GoogleAccountAdmin)

View File

@ -60,6 +60,7 @@ class FolderUpdate(BaseModel):
tags: list[str] | None = None tags: list[str] | None = None
check_interval: int | None = None check_interval: int | None = None
active: bool | None = None active: bool | None = None
exclude_folder_ids: list[str] | None = None
class FolderResponse(BaseModel): class FolderResponse(BaseModel):
@ -73,6 +74,7 @@ class FolderResponse(BaseModel):
check_interval: int check_interval: int
last_sync_at: str | None last_sync_at: str | None
active: bool active: bool
exclude_folder_ids: list[str]
class AccountResponse(BaseModel): class AccountResponse(BaseModel):
@ -370,6 +372,7 @@ def list_accounts(
folder.last_sync_at.isoformat() if folder.last_sync_at else None folder.last_sync_at.isoformat() if folder.last_sync_at else None
), ),
active=cast(bool, folder.active), active=cast(bool, folder.active),
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
) )
for folder in account.folders for folder in account.folders
], ],
@ -531,6 +534,7 @@ def add_folder(
check_interval=cast(int, new_folder.check_interval), check_interval=cast(int, new_folder.check_interval),
last_sync_at=None, last_sync_at=None,
active=cast(bool, new_folder.active), active=cast(bool, new_folder.active),
exclude_folder_ids=[],
) )
@ -567,6 +571,8 @@ def update_folder(
folder.check_interval = updates.check_interval folder.check_interval = updates.check_interval
if updates.active is not None: if updates.active is not None:
folder.active = updates.active folder.active = updates.active
if updates.exclude_folder_ids is not None:
folder.exclude_folder_ids = updates.exclude_folder_ids
db.commit() db.commit()
db.refresh(folder) db.refresh(folder)
@ -584,6 +590,7 @@ def update_folder(
folder.last_sync_at.isoformat() if folder.last_sync_at else None folder.last_sync_at.isoformat() if folder.last_sync_at else None
), ),
active=cast(bool, folder.active), active=cast(bool, folder.active),
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
) )

View File

@ -350,6 +350,9 @@ class GoogleFolder(Base):
# File type filters (empty = all text documents) # File type filters (empty = all text documents)
mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}") mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}")
# Excluded subfolder IDs (skip these when syncing recursively)
exclude_folder_ids = Column(ARRAY(Text), nullable=False, server_default="{}")
# Tags to apply to all documents from this folder # Tags to apply to all documents from this folder
tags = Column(ARRAY(Text), nullable=False, server_default="{}") tags = Column(ARRAY(Text), nullable=False, server_default="{}")

View File

@ -130,9 +130,19 @@ class GoogleDriveClient:
recursive: bool = True, recursive: bool = True,
since: datetime | None = None, since: datetime | None = None,
page_size: int = 100, page_size: int = 100,
exclude_folder_ids: set[str] | None = None,
) -> Generator[dict, None, None]: ) -> Generator[dict, None, None]:
"""List all supported files in a folder with pagination.""" """List all supported files in a folder with pagination.
Args:
folder_id: The Google Drive folder ID to list
recursive: Whether to recurse into subfolders
since: Only return files modified after this time
page_size: Number of files per API page
exclude_folder_ids: Set of folder IDs to skip during recursive traversal
"""
service = self._get_service() service = self._get_service()
exclude_folder_ids = exclude_folder_ids or set()
# Build query for supported file types # Build query for supported file types
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
@ -167,11 +177,16 @@ class GoogleDriveClient:
for file in response.get("files", []): for file in response.get("files", []):
if file["mimeType"] == "application/vnd.google-apps.folder": if file["mimeType"] == "application/vnd.google-apps.folder":
if recursive: if recursive and file["id"] not in exclude_folder_ids:
# Recursively list files in subfolder # Recursively list files in subfolder
yield from self.list_files_in_folder( yield from self.list_files_in_folder(
file["id"], recursive=True, since=since file["id"],
recursive=True,
since=since,
exclude_folder_ids=exclude_folder_ids,
) )
elif file["id"] in exclude_folder_ids:
logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
else: else:
yield file yield file

View File

@ -251,10 +251,16 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An
# It's a folder - list and sync all files inside # It's a folder - list and sync all files inside
folder_path = client.get_folder_path(google_id) folder_path = client.get_folder_path(google_id)
# Get excluded folder IDs
exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
if exclude_ids:
logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
for file_meta in client.list_files_in_folder( for file_meta in client.list_files_in_folder(
google_id, google_id,
recursive=cast(bool, folder.recursive), recursive=cast(bool, folder.recursive),
since=since, since=since,
exclude_folder_ids=exclude_ids,
): ):
try: try:
file_data = client.fetch_file(file_meta, folder_path) file_data = client.fetch_file(file_meta, folder_path)

View File

@ -5,10 +5,27 @@ import pytest
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock, MagicMock, patch from unittest.mock import AsyncMock, Mock, MagicMock, patch
# Mock the mcp module and all its submodules before importing anything that uses it # Mock FastMCP - this creates a decorator factory that passes through the function unchanged
class MockFastMCP:
def __init__(self, name):
self.name = name
def tool(self):
def decorator(func):
return func
return decorator
# Mock the fastmcp module before importing anything that uses it
_mock_fastmcp = MagicMock()
_mock_fastmcp.FastMCP = MockFastMCP
sys.modules["fastmcp"] = _mock_fastmcp
# Mock the mcp module and all its submodules
_mock_mcp = MagicMock() _mock_mcp = MagicMock()
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator _mock_mcp.tool = lambda: lambda f: f
sys.modules["mcp"] = _mock_mcp sys.modules["mcp"] = _mock_mcp
sys.modules["mcp.types"] = MagicMock()
sys.modules["mcp.server"] = MagicMock() sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.auth"] = MagicMock() sys.modules["mcp.server.auth"] = MagicMock()
sys.modules["mcp.server.auth.handlers"] = MagicMock() sys.modules["mcp.server.auth.handlers"] = MagicMock()
@ -21,7 +38,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
# Also mock the memory.api.MCP.base module to avoid MCP imports # Also mock the memory.api.MCP.base module to avoid MCP imports
_mock_base = MagicMock() _mock_base = MagicMock()
_mock_base.mcp = MagicMock() _mock_base.mcp = MagicMock()
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator _mock_base.mcp.tool = lambda: lambda f: f
sys.modules["memory.api.MCP.base"] = _mock_base sys.modules["memory.api.MCP.base"] = _mock_base
from memory.common.db.models import GithubItem from memory.common.db.models import GithubItem
@ -189,7 +206,7 @@ def sample_issues(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_no_filters(db_session, sample_issues): async def test_list_github_issues_no_filters(db_session, sample_issues):
"""Test listing all issues without filters.""" """Test listing all issues without filters."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues() results = await list_github_issues()
@ -203,7 +220,7 @@ async def test_list_github_issues_no_filters(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_repo(db_session, sample_issues): async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
"""Test filtering by repository.""" """Test filtering by repository."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(repo="owner/repo1") results = await list_github_issues(repo="owner/repo1")
@ -215,7 +232,7 @@ async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_assignee(db_session, sample_issues): async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
"""Test filtering by assignee.""" """Test filtering by assignee."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(assignee="alice") results = await list_github_issues(assignee="alice")
@ -227,7 +244,7 @@ async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_author(db_session, sample_issues): async def test_list_github_issues_filter_by_author(db_session, sample_issues):
"""Test filtering by author.""" """Test filtering by author."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(author="alice") results = await list_github_issues(author="alice")
@ -239,7 +256,7 @@ async def test_list_github_issues_filter_by_author(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_state(db_session, sample_issues): async def test_list_github_issues_filter_by_state(db_session, sample_issues):
"""Test filtering by state.""" """Test filtering by state."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
open_results = await list_github_issues(state="open") open_results = await list_github_issues(state="open")
@ -259,7 +276,7 @@ async def test_list_github_issues_filter_by_state(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_kind(db_session, sample_issues): async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
"""Test filtering by kind (issue vs PR).""" """Test filtering by kind (issue vs PR)."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
issues = await list_github_issues(kind="issue") issues = await list_github_issues(kind="issue")
@ -275,7 +292,7 @@ async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_labels(db_session, sample_issues): async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
"""Test filtering by labels.""" """Test filtering by labels."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(labels=["bug"]) results = await list_github_issues(labels=["bug"])
@ -287,7 +304,7 @@ async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_project_status(db_session, sample_issues): async def test_list_github_issues_filter_by_project_status(db_session, sample_issues):
"""Test filtering by project status.""" """Test filtering by project status."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(project_status="In Progress") results = await list_github_issues(project_status="In Progress")
@ -300,7 +317,7 @@ async def test_list_github_issues_filter_by_project_status(db_session, sample_is
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_project_field(db_session, sample_issues): async def test_list_github_issues_filter_by_project_field(db_session, sample_issues):
"""Test filtering by project field (JSONB).""" """Test filtering by project field (JSONB)."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues( results = await list_github_issues(
@ -316,7 +333,7 @@ async def test_list_github_issues_filter_by_project_field(db_session, sample_iss
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues): async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues):
"""Test filtering by updated_since.""" """Test filtering by updated_since."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=2)).isoformat() since = (now - timedelta(days=2)).isoformat()
@ -332,7 +349,7 @@ async def test_list_github_issues_filter_by_updated_since(db_session, sample_iss
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues): async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues):
"""Test filtering by updated_before (stale issues).""" """Test filtering by updated_before (stale issues)."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
before = (now - timedelta(days=30)).isoformat() before = (now - timedelta(days=30)).isoformat()
@ -348,7 +365,7 @@ async def test_list_github_issues_filter_by_updated_before(db_session, sample_is
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_order_by_created(db_session, sample_issues): async def test_list_github_issues_order_by_created(db_session, sample_issues):
"""Test ordering by created date.""" """Test ordering by created date."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(order_by="created") results = await list_github_issues(order_by="created")
@ -360,7 +377,7 @@ async def test_list_github_issues_order_by_created(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_limit(db_session, sample_issues): async def test_list_github_issues_limit(db_session, sample_issues):
"""Test limiting results.""" """Test limiting results."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(limit=2) results = await list_github_issues(limit=2)
@ -371,7 +388,7 @@ async def test_list_github_issues_limit(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_limit_max_enforced(db_session, sample_issues): async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
"""Test that limit is capped at 200.""" """Test that limit is capped at 200."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
# Request 500 but should be capped at 200 # Request 500 but should be capped at 200
@ -384,7 +401,7 @@ async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_combined_filters(db_session, sample_issues): async def test_list_github_issues_combined_filters(db_session, sample_issues):
"""Test combining multiple filters.""" """Test combining multiple filters."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues( results = await list_github_issues(
@ -402,7 +419,7 @@ async def test_list_github_issues_combined_filters(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_github_issues_url_construction(db_session, sample_issues): async def test_list_github_issues_url_construction(db_session, sample_issues):
"""Test that URLs are correctly constructed.""" """Test that URLs are correctly constructed."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(kind="issue", limit=1) results = await list_github_issues(kind="issue", limit=1)
@ -425,7 +442,7 @@ async def test_list_github_issues_url_construction(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_issue_details_found(db_session, sample_issues): async def test_github_issue_details_found(db_session, sample_issues):
"""Test getting details for an existing issue.""" """Test getting details for an existing issue."""
from memory.api.MCP.github import github_issue_details from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=1) result = await github_issue_details(repo="owner/repo1", number=1)
@ -440,7 +457,7 @@ async def test_github_issue_details_found(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_issue_details_not_found(db_session, sample_issues): async def test_github_issue_details_not_found(db_session, sample_issues):
"""Test getting details for a non-existent issue.""" """Test getting details for a non-existent issue."""
from memory.api.MCP.github import github_issue_details from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
@ -450,7 +467,7 @@ async def test_github_issue_details_not_found(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_issue_details_pr(db_session, sample_issues): async def test_github_issue_details_pr(db_session, sample_issues):
"""Test getting details for a PR.""" """Test getting details for a PR."""
from memory.api.MCP.github import github_issue_details from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=50) result = await github_issue_details(repo="owner/repo1", number=50)
@ -468,7 +485,7 @@ async def test_github_issue_details_pr(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_by_client(db_session, sample_issues): async def test_github_work_summary_by_client(db_session, sample_issues):
"""Test work summary grouped by client.""" """Test work summary grouped by client."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -489,7 +506,7 @@ async def test_github_work_summary_by_client(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_by_status(db_session, sample_issues): async def test_github_work_summary_by_status(db_session, sample_issues):
"""Test work summary grouped by status.""" """Test work summary grouped by status."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -506,7 +523,7 @@ async def test_github_work_summary_by_status(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_by_author(db_session, sample_issues): async def test_github_work_summary_by_author(db_session, sample_issues):
"""Test work summary grouped by author.""" """Test work summary grouped by author."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -522,7 +539,7 @@ async def test_github_work_summary_by_author(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_by_repo(db_session, sample_issues): async def test_github_work_summary_by_repo(db_session, sample_issues):
"""Test work summary grouped by repository.""" """Test work summary grouped by repository."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -538,7 +555,7 @@ async def test_github_work_summary_by_repo(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_with_until(db_session, sample_issues): async def test_github_work_summary_with_until(db_session, sample_issues):
"""Test work summary with until date.""" """Test work summary with until date."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -553,7 +570,7 @@ async def test_github_work_summary_with_until(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_with_repo_filter(db_session, sample_issues): async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
"""Test work summary filtered by repository.""" """Test work summary filtered by repository."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -571,7 +588,7 @@ async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_invalid_group_by(db_session, sample_issues): async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
"""Test work summary with invalid group_by value.""" """Test work summary with invalid group_by value."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -584,7 +601,7 @@ async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_work_summary_includes_sample_issues(db_session, sample_issues): async def test_github_work_summary_includes_sample_issues(db_session, sample_issues):
"""Test that work summary includes sample issues.""" """Test that work summary includes sample issues."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -610,7 +627,7 @@ async def test_github_work_summary_includes_sample_issues(db_session, sample_iss
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_basic(db_session, sample_issues): async def test_github_repo_overview_basic(db_session, sample_issues):
"""Test basic repo overview.""" """Test basic repo overview."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -625,7 +642,7 @@ async def test_github_repo_overview_basic(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_counts(db_session, sample_issues): async def test_github_repo_overview_counts(db_session, sample_issues):
"""Test repo overview counts are correct.""" """Test repo overview counts are correct."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -638,7 +655,7 @@ async def test_github_repo_overview_counts(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_status_breakdown(db_session, sample_issues): async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
"""Test repo overview includes status breakdown.""" """Test repo overview includes status breakdown."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -651,7 +668,7 @@ async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_top_assignees(db_session, sample_issues): async def test_github_repo_overview_top_assignees(db_session, sample_issues):
"""Test repo overview includes top assignees.""" """Test repo overview includes top assignees."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -663,7 +680,7 @@ async def test_github_repo_overview_top_assignees(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_labels(db_session, sample_issues): async def test_github_repo_overview_labels(db_session, sample_issues):
"""Test repo overview includes labels.""" """Test repo overview includes labels."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -676,7 +693,7 @@ async def test_github_repo_overview_labels(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_last_updated(db_session, sample_issues): async def test_github_repo_overview_last_updated(db_session, sample_issues):
"""Test repo overview includes last updated timestamp.""" """Test repo overview includes last updated timestamp."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1") result = await github_repo_overview(repo="owner/repo1")
@ -688,7 +705,7 @@ async def test_github_repo_overview_last_updated(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_repo_overview_empty_repo(db_session): async def test_github_repo_overview_empty_repo(db_session):
"""Test repo overview for a repo with no issues.""" """Test repo overview for a repo with no issues."""
from memory.api.MCP.github import github_repo_overview from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="nonexistent/repo") result = await github_repo_overview(repo="nonexistent/repo")
@ -704,7 +721,7 @@ async def test_github_repo_overview_empty_repo(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_github_issues_basic(db_session, sample_issues): async def test_search_github_issues_basic(db_session, sample_issues):
"""Test basic search functionality.""" """Test basic search functionality."""
from memory.api.MCP.github import search_github_issues from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock() mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id mock_search_result.id = sample_issues[0].id
@ -723,7 +740,7 @@ async def test_search_github_issues_basic(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_github_issues_with_repo_filter(db_session, sample_issues): async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
"""Test search with repository filter.""" """Test search with repository filter."""
from memory.api.MCP.github import search_github_issues from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock() mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id mock_search_result.id = sample_issues[0].id
@ -745,7 +762,7 @@ async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_github_issues_with_state_filter(db_session, sample_issues): async def test_search_github_issues_with_state_filter(db_session, sample_issues):
"""Test search with state filter.""" """Test search with state filter."""
from memory.api.MCP.github import search_github_issues from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock() mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id mock_search_result.id = sample_issues[0].id
@ -762,7 +779,7 @@ async def test_search_github_issues_with_state_filter(db_session, sample_issues)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_github_issues_limit(db_session, sample_issues): async def test_search_github_issues_limit(db_session, sample_issues):
"""Test search respects limit.""" """Test search respects limit."""
from memory.api.MCP.github import search_github_issues from memory.api.MCP.servers.github import search_github_issues
mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])] mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])]
@ -779,7 +796,7 @@ async def test_search_github_issues_limit(db_session, sample_issues):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_github_issues_uses_github_modality(db_session, sample_issues): async def test_search_github_issues_uses_github_modality(db_session, sample_issues):
"""Test that search uses github modality.""" """Test that search uses github modality."""
from memory.api.MCP.github import search_github_issues from memory.api.MCP.servers.github import search_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search:
@ -797,7 +814,7 @@ async def test_search_github_issues_uses_github_modality(db_session, sample_issu
def test_build_github_url_issue(): def test_build_github_url_issue():
"""Test URL construction for issues.""" """Test URL construction for issues."""
from memory.api.MCP.github import _build_github_url from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", 123, "issue") url = _build_github_url("owner/repo", 123, "issue")
assert url == "https://github.com/owner/repo/issues/123" assert url == "https://github.com/owner/repo/issues/123"
@ -805,7 +822,7 @@ def test_build_github_url_issue():
def test_build_github_url_pr(): def test_build_github_url_pr():
"""Test URL construction for PRs.""" """Test URL construction for PRs."""
from memory.api.MCP.github import _build_github_url from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", 456, "pr") url = _build_github_url("owner/repo", 456, "pr")
assert url == "https://github.com/owner/repo/pull/456" assert url == "https://github.com/owner/repo/pull/456"
@ -813,7 +830,7 @@ def test_build_github_url_pr():
def test_build_github_url_no_number(): def test_build_github_url_no_number():
"""Test URL construction without number.""" """Test URL construction without number."""
from memory.api.MCP.github import _build_github_url from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", None, "issue") url = _build_github_url("owner/repo", None, "issue")
assert url == "https://github.com/owner/repo" assert url == "https://github.com/owner/repo"
@ -821,7 +838,7 @@ def test_build_github_url_no_number():
def test_serialize_issue_basic(db_session, sample_issues): def test_serialize_issue_basic(db_session, sample_issues):
"""Test issue serialization.""" """Test issue serialization."""
from memory.api.MCP.github import _serialize_issue from memory.api.MCP.servers.github import _serialize_issue
issue = sample_issues[0] issue = sample_issues[0]
result = _serialize_issue(issue) result = _serialize_issue(issue)
@ -838,7 +855,7 @@ def test_serialize_issue_basic(db_session, sample_issues):
def test_serialize_issue_with_content(db_session, sample_issues): def test_serialize_issue_with_content(db_session, sample_issues):
"""Test issue serialization with content.""" """Test issue serialization with content."""
from memory.api.MCP.github import _serialize_issue from memory.api.MCP.servers.github import _serialize_issue
issue = sample_issues[0] issue = sample_issues[0]
result = _serialize_issue(issue, include_content=True) result = _serialize_issue(issue, include_content=True)
@ -865,7 +882,7 @@ async def test_list_github_issues_ordering(
db_session, sample_issues, order_by, expected_first_number db_session, sample_issues, order_by, expected_first_number
): ):
"""Test different ordering options.""" """Test different ordering options."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(order_by=order_by) results = await list_github_issues(order_by=order_by)
@ -880,7 +897,7 @@ async def test_list_github_issues_ordering(
) )
async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by): async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by):
"""Test all valid group_by options.""" """Test all valid group_by options."""
from memory.api.MCP.github import github_work_summary from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat() since = (now - timedelta(days=30)).isoformat()
@ -906,7 +923,7 @@ async def test_list_github_issues_label_filtering(
db_session, sample_issues, labels, expected_count db_session, sample_issues, labels, expected_count
): ):
"""Test various label filtering scenarios.""" """Test various label filtering scenarios."""
from memory.api.MCP.github import list_github_issues from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(labels=labels) results = await list_github_issues(labels=labels)
@ -1014,7 +1031,7 @@ def sample_pr_with_data(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data): async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data):
"""Test that github_issue_details includes PR data for PRs.""" """Test that github_issue_details includes PR data for PRs."""
from memory.api.MCP.github import github_issue_details from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=999) result = await github_issue_details(repo="owner/repo1", number=999)
@ -1034,7 +1051,7 @@ async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues): async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues):
"""Test that github_issue_details does not include pr_data for issues.""" """Test that github_issue_details does not include pr_data for issues."""
from memory.api.MCP.github import github_issue_details from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=1) result = await github_issue_details(repo="owner/repo1", number=1)
@ -1045,7 +1062,7 @@ async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_iss
def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data): def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
"""Test that _serialize_issue includes pr_data when include_content=True.""" """Test that _serialize_issue includes pr_data when include_content=True."""
from memory.api.MCP.github import _serialize_issue from memory.api.MCP.servers.github import _serialize_issue
result = _serialize_issue(sample_pr_with_data, include_content=True) result = _serialize_issue(sample_pr_with_data, include_content=True)
@ -1056,7 +1073,7 @@ def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data): def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data):
"""Test that _serialize_issue excludes pr_data when include_content=False.""" """Test that _serialize_issue excludes pr_data when include_content=False."""
from memory.api.MCP.github import _serialize_issue from memory.api.MCP.servers.github import _serialize_issue
result = _serialize_issue(sample_pr_with_data, include_content=False) result = _serialize_issue(sample_pr_with_data, include_content=False)

View File

@ -4,10 +4,27 @@ import sys
import pytest import pytest
from unittest.mock import AsyncMock, Mock, MagicMock, patch from unittest.mock import AsyncMock, Mock, MagicMock, patch
# Mock the mcp module and all its submodules before importing anything that uses it # Mock FastMCP - this creates a decorator factory that passes through the function unchanged
class MockFastMCP:
def __init__(self, name):
self.name = name
def tool(self):
def decorator(func):
return func
return decorator
# Mock the fastmcp module before importing anything that uses it
_mock_fastmcp = MagicMock()
_mock_fastmcp.FastMCP = MockFastMCP
sys.modules["fastmcp"] = _mock_fastmcp
# Mock the mcp module and all its submodules
_mock_mcp = MagicMock() _mock_mcp = MagicMock()
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator _mock_mcp.tool = lambda: lambda f: f
sys.modules["mcp"] = _mock_mcp sys.modules["mcp"] = _mock_mcp
sys.modules["mcp.types"] = MagicMock()
sys.modules["mcp.server"] = MagicMock() sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.auth"] = MagicMock() sys.modules["mcp.server.auth"] = MagicMock()
sys.modules["mcp.server.auth.handlers"] = MagicMock() sys.modules["mcp.server.auth.handlers"] = MagicMock()
@ -20,7 +37,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
# Also mock the memory.api.MCP.base module to avoid MCP imports # Also mock the memory.api.MCP.base module to avoid MCP imports
_mock_base = MagicMock() _mock_base = MagicMock()
_mock_base.mcp = MagicMock() _mock_base.mcp = MagicMock()
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator _mock_base.mcp.tool = lambda: lambda f: f
sys.modules["memory.api.MCP.base"] = _mock_base sys.modules["memory.api.MCP.base"] = _mock_base
from memory.common.db.models import Person from memory.common.db.models import Person
@ -97,7 +114,7 @@ def sample_people(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_person_success(db_session): async def test_add_person_success(db_session):
"""Test adding a new person.""" """Test adding a new person."""
from memory.api.MCP.people import add_person from memory.api.MCP.servers.people import add_person
mock_task = Mock() mock_task = Mock()
mock_task.id = "task-123" mock_task.id = "task-123"
@ -128,7 +145,7 @@ async def test_add_person_success(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_person_already_exists(db_session, sample_people): async def test_add_person_already_exists(db_session, sample_people):
"""Test adding a person that already exists.""" """Test adding a person that already exists."""
from memory.api.MCP.people import add_person from memory.api.MCP.servers.people import add_person
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="already exists"): with pytest.raises(ValueError, match="already exists"):
@ -141,7 +158,7 @@ async def test_add_person_already_exists(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_person_minimal(db_session): async def test_add_person_minimal(db_session):
"""Test adding a person with minimal data.""" """Test adding a person with minimal data."""
from memory.api.MCP.people import add_person from memory.api.MCP.servers.people import add_person
mock_task = Mock() mock_task = Mock()
mock_task.id = "task-456" mock_task.id = "task-456"
@ -166,7 +183,7 @@ async def test_add_person_minimal(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_person_info_success(db_session, sample_people): async def test_update_person_info_success(db_session, sample_people):
"""Test updating a person's info.""" """Test updating a person's info."""
from memory.api.MCP.people import update_person_info from memory.api.MCP.servers.people import update_person_info
mock_task = Mock() mock_task = Mock()
mock_task.id = "task-789" mock_task.id = "task-789"
@ -188,7 +205,7 @@ async def test_update_person_info_success(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_person_info_not_found(db_session, sample_people): async def test_update_person_info_not_found(db_session, sample_people):
"""Test updating a person that doesn't exist.""" """Test updating a person that doesn't exist."""
from memory.api.MCP.people import update_person_info from memory.api.MCP.servers.people import update_person_info
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
@ -201,7 +218,7 @@ async def test_update_person_info_not_found(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_person_info_with_merge_params(db_session, sample_people): async def test_update_person_info_with_merge_params(db_session, sample_people):
"""Test that update passes all merge parameters.""" """Test that update passes all merge parameters."""
from memory.api.MCP.people import update_person_info from memory.api.MCP.servers.people import update_person_info
mock_task = Mock() mock_task = Mock()
mock_task.id = "task-merge" mock_task.id = "task-merge"
@ -234,7 +251,7 @@ async def test_update_person_info_with_merge_params(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_person_found(db_session, sample_people): async def test_get_person_found(db_session, sample_people):
"""Test getting a person that exists.""" """Test getting a person that exists."""
from memory.api.MCP.people import get_person from memory.api.MCP.servers.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="alice_chen") result = await get_person(identifier="alice_chen")
@ -251,7 +268,7 @@ async def test_get_person_found(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_person_not_found(db_session, sample_people): async def test_get_person_not_found(db_session, sample_people):
"""Test getting a person that doesn't exist.""" """Test getting a person that doesn't exist."""
from memory.api.MCP.people import get_person from memory.api.MCP.servers.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="nonexistent_person") result = await get_person(identifier="nonexistent_person")
@ -267,7 +284,7 @@ async def test_get_person_not_found(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_no_filters(db_session, sample_people): async def test_list_people_no_filters(db_session, sample_people):
"""Test listing all people without filters.""" """Test listing all people without filters."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people() results = await list_people()
@ -282,7 +299,7 @@ async def test_list_people_no_filters(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_filter_by_tags(db_session, sample_people): async def test_list_people_filter_by_tags(db_session, sample_people):
"""Test filtering by tags.""" """Test filtering by tags."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"]) results = await list_people(tags=["work"])
@ -294,7 +311,7 @@ async def test_list_people_filter_by_tags(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_filter_by_search(db_session, sample_people): async def test_list_people_filter_by_search(db_session, sample_people):
"""Test filtering by search term.""" """Test filtering by search term."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="alice") results = await list_people(search="alice")
@ -306,7 +323,7 @@ async def test_list_people_filter_by_search(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_search_in_notes(db_session, sample_people): async def test_list_people_search_in_notes(db_session, sample_people):
"""Test that search works on notes content.""" """Test that search works on notes content."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="climbing") results = await list_people(search="climbing")
@ -318,7 +335,7 @@ async def test_list_people_search_in_notes(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_limit(db_session, sample_people): async def test_list_people_limit(db_session, sample_people):
"""Test limiting results.""" """Test limiting results."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(limit=1) results = await list_people(limit=1)
@ -329,7 +346,7 @@ async def test_list_people_limit(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_limit_max_enforced(db_session, sample_people): async def test_list_people_limit_max_enforced(db_session, sample_people):
"""Test that limit is capped at 200.""" """Test that limit is capped at 200."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
# Request 500 but should be capped at 200 # Request 500 but should be capped at 200
@ -342,7 +359,7 @@ async def test_list_people_limit_max_enforced(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_people_combined_filters(db_session, sample_people): async def test_list_people_combined_filters(db_session, sample_people):
"""Test combining tag and search filters.""" """Test combining tag and search filters."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"], search="chen") results = await list_people(tags=["work"], search="chen")
@ -359,7 +376,7 @@ async def test_list_people_combined_filters(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_person_success(db_session, sample_people): async def test_delete_person_success(db_session, sample_people):
"""Test deleting a person.""" """Test deleting a person."""
from memory.api.MCP.people import delete_person from memory.api.MCP.servers.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await delete_person(identifier="bob_smith") result = await delete_person(identifier="bob_smith")
@ -376,7 +393,7 @@ async def test_delete_person_success(db_session, sample_people):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_person_not_found(db_session, sample_people): async def test_delete_person_not_found(db_session, sample_people):
"""Test deleting a person that doesn't exist.""" """Test deleting a person that doesn't exist."""
from memory.api.MCP.people import delete_person from memory.api.MCP.servers.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
@ -390,7 +407,7 @@ async def test_delete_person_not_found(db_session, sample_people):
def test_person_to_dict(sample_people): def test_person_to_dict(sample_people):
"""Test the _person_to_dict helper function.""" """Test the _person_to_dict helper function."""
from memory.api.MCP.people import _person_to_dict from memory.api.MCP.servers.people import _person_to_dict
person = sample_people[0] person = sample_people[0]
result = _person_to_dict(person) result = _person_to_dict(person)
@ -406,7 +423,7 @@ def test_person_to_dict(sample_people):
def test_person_to_dict_empty_fields(db_session): def test_person_to_dict_empty_fields(db_session):
"""Test _person_to_dict with empty optional fields.""" """Test _person_to_dict with empty optional fields."""
from memory.api.MCP.people import _person_to_dict from memory.api.MCP.servers.people import _person_to_dict
person = Person( person = Person(
identifier="empty_person", identifier="empty_person",
@ -448,7 +465,7 @@ def test_person_to_dict_empty_fields(db_session):
) )
async def test_list_people_various_tags(db_session, sample_people, tag, expected_count): async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
"""Test filtering by various tags.""" """Test filtering by various tags."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=[tag]) results = await list_people(tags=[tag])
@ -473,7 +490,7 @@ async def test_list_people_various_searches(
db_session, sample_people, search_term, expected_identifiers db_session, sample_people, search_term, expected_identifiers
): ):
"""Test search with various terms.""" """Test search with various terms."""
from memory.api.MCP.people import list_people from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session): with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search=search_term) results = await list_people(search=search_term)

View File

@ -69,8 +69,12 @@ class TestSearchResult:
result = SearchResult.from_source_item(source, chunks) result = SearchResult.from_source_item(source, chunks)
assert result.search_score == 0.9 assert result.search_score == 0.9
def test_search_score_multiple_chunks_uses_mean(self): def test_search_score_multiple_chunks_uses_max(self):
"""Multiple chunks should use mean of relevance scores, not sum.""" """Multiple chunks should use max of relevance scores.
Using max finds documents with at least one highly relevant section,
which is better for 'half-remembered' searches where users recall one detail.
"""
source = self._make_source_item() source = self._make_source_item()
chunks = [ chunks = [
self._make_chunk(0.9), self._make_chunk(0.9),
@ -79,8 +83,8 @@ class TestSearchResult:
] ]
result = SearchResult.from_source_item(source, chunks) result = SearchResult.from_source_item(source, chunks)
# Mean of 0.9, 0.7, 0.8 = 0.8 # Max of 0.9, 0.7, 0.8 = 0.9
assert result.search_score == pytest.approx(0.8) assert result.search_score == pytest.approx(0.9)
def test_search_score_empty_chunks(self): def test_search_score_empty_chunks(self):
"""Empty chunk list should result in None or 0 score.""" """Empty chunk list should result in None or 0 score."""