exclude stuff

This commit is contained in:
Daniel O'Connell 2025-12-29 17:03:08 +01:00
parent 59c45ff1fb
commit 9cf71c9336
12 changed files with 546 additions and 86 deletions

View File

@ -0,0 +1,35 @@
"""Add exclude_folder_ids to google_folders
Revision ID: add_exclude_folder_ids
Revises: 20251229_120000_add_google_drive
Create Date: 2025-12-29 15:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "add_exclude_folder_ids"
down_revision: Union[str, None] = "f2a3b4c5d6e7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"google_folders",
sa.Column(
"exclude_folder_ids",
postgresql.ARRAY(sa.Text()),
nullable=False,
server_default="{}",
),
)
def downgrade() -> None:
op.drop_column("google_folders", "exclude_folder_ids")

View File

@ -1936,3 +1936,97 @@ a.folder-item-name:hover {
.add-btn.secondary:hover {
background: #edf2f7;
}
/* === Exclusion Browser === */
.exclusion-browser {
display: flex;
flex-direction: column;
min-height: 400px;
max-height: 60vh;
}
.excluded-list {
background: #fffbeb;
border: 1px solid #f59e0b;
border-radius: 6px;
padding: 0.75rem;
margin-bottom: 0.75rem;
}
.excluded-list h5 {
font-size: 0.85rem;
color: #92400e;
margin-bottom: 0.5rem;
}
.excluded-items {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
}
.excluded-item {
display: flex;
align-items: center;
gap: 0.25rem;
background: #fef3c7;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.85rem;
}
.excluded-name {
color: #92400e;
}
.remove-exclusion-btn {
background: none;
border: none;
color: #b45309;
cursor: pointer;
padding: 0 0.25rem;
font-size: 0.75rem;
line-height: 1;
}
.remove-exclusion-btn:hover {
color: #dc2626;
}
.folder-item.excluded {
background: #fffbeb;
border-color: #f59e0b;
}
.exclusion-badge {
background: #f59e0b;
color: white;
padding: 0.125rem 0.375rem;
border-radius: 3px;
font-size: 0.7rem;
margin-left: auto;
margin-right: 0.5rem;
}
.exclusions-btn {
background: #fffbeb;
color: #b45309;
border: 1px solid #f59e0b;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
cursor: pointer;
transition: all 0.15s ease;
}
.exclusions-btn:hover {
background: #fef3c7;
color: #92400e;
}
.setting-badge.warning {
background: #fffbeb;
color: #b45309;
border: 1px solid #f59e0b;
}

View File

@ -1,6 +1,6 @@
import { useState, useEffect, useCallback } from 'react'
import { Link } from 'react-router-dom'
import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources'
import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleFolder, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources'
import {
SourceCard,
Modal,
@ -956,6 +956,7 @@ const GoogleDrivePanel = () => {
const [error, setError] = useState<string | null>(null)
const [addingFolderTo, setAddingFolderTo] = useState<number | null>(null)
const [browsingFoldersFor, setBrowsingFoldersFor] = useState<number | null>(null)
const [managingExclusionsFor, setManagingExclusionsFor] = useState<{ accountId: number; folder: GoogleFolder } | null>(null)
const [uploadingConfig, setUploadingConfig] = useState(false)
const loadData = useCallback(async () => {
@ -1044,6 +1045,12 @@ const GoogleDrivePanel = () => {
loadData()
}
const handleUpdateExclusions = async (accountId: number, folderId: number, excludeIds: string[]) => {
await updateGoogleFolder(accountId, folderId, { exclude_folder_ids: excludeIds })
setManagingExclusionsFor(null)
loadData()
}
const handleDeleteFolder = async (accountId: number, folderId: number) => {
await deleteGoogleFolder(accountId, folderId)
loadData()
@ -1188,6 +1195,11 @@ const GoogleDrivePanel = () => {
<div className="folder-settings">
{folder.recursive && <span className="setting-badge">Recursive</span>}
{folder.include_shared && <span className="setting-badge">Shared</span>}
{folder.exclude_folder_ids.length > 0 && (
<span className="setting-badge warning">
{folder.exclude_folder_ids.length} excluded
</span>
)}
</div>
<div className="folder-actions">
<SyncButton
@ -1195,6 +1207,14 @@ const GoogleDrivePanel = () => {
disabled={!folder.active || !account.active}
label="Sync"
/>
{folder.recursive && (
<button
className="exclusions-btn"
onClick={() => setManagingExclusionsFor({ accountId: account.id, folder })}
>
Exclusions
</button>
)}
<button
className="toggle-btn"
onClick={() => handleToggleFolderActive(account.id, folder.id, folder.active)}
@ -1233,6 +1253,15 @@ const GoogleDrivePanel = () => {
onCancel={() => setBrowsingFoldersFor(null)}
/>
)}
{managingExclusionsFor && (
<ExclusionBrowser
accountId={managingExclusionsFor.accountId}
folder={managingExclusionsFor.folder}
onSave={(excludeIds) => handleUpdateExclusions(managingExclusionsFor.accountId, managingExclusionsFor.folder.id, excludeIds)}
onCancel={() => setManagingExclusionsFor(null)}
/>
)}
</div>
)
}
@ -1422,6 +1451,195 @@ const FolderBrowser = ({ accountId, onSelect, onCancel }: FolderBrowserProps) =>
)
}
// === Exclusion Browser ===
interface ExclusionBrowserProps {
accountId: number
folder: GoogleFolder
onSave: (excludeIds: string[]) => void
onCancel: () => void
}
interface ExcludedFolder {
id: string
name: string
path: string
}
const ExclusionBrowser = ({ accountId, folder, onSave, onCancel }: ExclusionBrowserProps) => {
const { browseGoogleDrive } = useSources()
const [path, setPath] = useState<PathItem[]>([{ id: folder.folder_id, name: folder.folder_name }])
const [items, setItems] = useState<DriveItem[]>([])
const [excluded, setExcluded] = useState<Map<string, ExcludedFolder>>(() => {
// Initialize with current exclusions (we don't have names, so use ID as name)
const map = new Map<string, ExcludedFolder>()
for (const id of folder.exclude_folder_ids) {
map.set(id, { id, name: id, path: '(previously excluded)' })
}
return map
})
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const currentFolderId = path[path.length - 1].id
const currentPath = path.map(p => p.name).join(' > ')
const loadFolder = useCallback(async (folderId: string) => {
setLoading(true)
setError(null)
try {
const response = await browseGoogleDrive(accountId, folderId)
// Only show folders, not files
setItems(response.items.filter(item => item.is_folder))
} catch (e) {
setError(e instanceof Error ? e.message : 'Failed to load folder')
} finally {
setLoading(false)
}
}, [accountId, browseGoogleDrive])
useEffect(() => {
loadFolder(currentFolderId)
}, [currentFolderId, loadFolder])
const navigateToFolder = (item: DriveItem) => {
setPath([...path, { id: item.id, name: item.name }])
}
const navigateToPathIndex = (index: number) => {
setPath(path.slice(0, index + 1))
}
const toggleExclude = (item: DriveItem) => {
const newExcluded = new Map(excluded)
if (newExcluded.has(item.id)) {
newExcluded.delete(item.id)
} else {
newExcluded.set(item.id, {
id: item.id,
name: item.name,
path: currentPath + ' > ' + item.name,
})
}
setExcluded(newExcluded)
}
const removeExclusion = (id: string) => {
const newExcluded = new Map(excluded)
newExcluded.delete(id)
setExcluded(newExcluded)
}
const handleSave = () => {
onSave(Array.from(excluded.keys()))
}
return (
<Modal title={`Manage Exclusions: ${folder.folder_name}`} onClose={onCancel}>
<div className="exclusion-browser">
{/* Current exclusions */}
{excluded.size > 0 && (
<div className="excluded-list">
<h5>Excluded Folders ({excluded.size})</h5>
<div className="excluded-items">
{Array.from(excluded.values()).map(item => (
<div key={item.id} className="excluded-item">
<span className="excluded-name" title={item.path}>
📁 {item.name}
</span>
<button
className="remove-exclusion-btn"
onClick={() => removeExclusion(item.id)}
title="Remove exclusion"
>
</button>
</div>
))}
</div>
</div>
)}
{/* Breadcrumb */}
<div className="folder-breadcrumb">
{path.map((item, index) => (
<span key={item.id}>
{index > 0 && <span className="breadcrumb-sep">&gt;</span>}
<button
className={`breadcrumb-item ${index === path.length - 1 ? 'current' : ''}`}
onClick={() => navigateToPathIndex(index)}
disabled={index === path.length - 1}
>
{item.name}
</button>
</span>
))}
</div>
{/* Content */}
{error && <div className="form-error">{error}</div>}
{loading ? (
<div className="folder-loading">Loading...</div>
) : items.length === 0 ? (
<div className="folder-empty">No subfolders in this folder</div>
) : (
<div className="folder-list">
{items.map(item => (
<div key={item.id} className={`folder-item ${excluded.has(item.id) ? 'excluded' : ''}`}>
<label className="folder-item-checkbox">
<input
type="checkbox"
checked={excluded.has(item.id)}
onChange={() => toggleExclude(item)}
/>
</label>
<span className="folder-item-icon">📁</span>
<a
className="folder-item-name"
href={`https://drive.google.com/drive/folders/${item.id}`}
target="_blank"
rel="noopener noreferrer"
onClick={(e) => e.stopPropagation()}
>
{item.name}
</a>
{excluded.has(item.id) && (
<span className="exclusion-badge">Excluded</span>
)}
<button
className="folder-item-enter"
onClick={() => navigateToFolder(item)}
title="Browse subfolder"
>
</button>
</div>
))}
</div>
)}
{/* Footer */}
<div className="folder-browser-footer">
<span className="selected-count">
{excluded.size} folder{excluded.size !== 1 ? 's' : ''} excluded
</span>
<div className="folder-browser-actions">
<button type="button" className="cancel-btn" onClick={onCancel}>Cancel</button>
<button
type="button"
className="submit-btn"
onClick={handleSave}
>
Save Exclusions
</button>
</div>
</div>
</div>
</Modal>
)
}
interface GoogleFolderFormProps {
accountId: number
onSubmit: (data: any) => Promise<void>

View File

@ -177,6 +177,7 @@ export interface GoogleFolder {
check_interval: number
last_sync_at: string | null
active: boolean
exclude_folder_ids: string[]
}
export interface GoogleAccount {
@ -205,6 +206,7 @@ export interface GoogleFolderUpdate {
tags?: string[]
check_interval?: number
active?: boolean
exclude_folder_ids?: string[]
}
// Types for Google Drive browsing

View File

@ -29,6 +29,9 @@ from memory.common.db.models import (
ScheduledLLMCall,
SourceItem,
User,
GoogleDoc,
GoogleFolder,
GoogleAccount,
)
from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser
@ -394,6 +397,42 @@ class GithubItemAdmin(ModelView, model=GithubItem):
column_sortable_list = ["github_updated_at", "created_at"]
class GoogleDocAdmin(ModelView, model=GoogleDoc):
column_list = source_columns(
GoogleDoc,
"title",
"folder_path",
"owner",
"last_modified_by",
"word_count",
"content_hash",
)
column_searchable_list = ["title", "folder_path", "owner", "last_modified_by", "id"]
column_sortable_list = ["google_modified_at", "created_at"]
class GoogleFolderAdmin(ModelView, model=GoogleFolder):
column_list = source_columns(
GoogleFolder, "folder_name", "folder_path", "account", "active"
)
column_searchable_list = ["folder_name", "folder_path", "id"]
column_sortable_list = ["last_sync_at", "created_at"]
class GoogleAccountAdmin(ModelView, model=GoogleAccount):
column_list = [
"id",
"name",
"email",
"active",
"last_sync_at",
"created_at",
"updated_at",
]
column_searchable_list = ["name", "email", "id"]
column_sortable_list = ["last_sync_at", "created_at"]
def setup_admin(admin: Admin):
"""Add all admin views to the admin instance with OAuth protection."""
admin.add_view(SourceItemAdmin)
@ -421,3 +460,6 @@ def setup_admin(admin: Admin):
admin.add_view(GithubAccountAdmin)
admin.add_view(GithubRepoAdmin)
admin.add_view(GithubItemAdmin)
admin.add_view(GoogleDocAdmin)
admin.add_view(GoogleFolderAdmin)
admin.add_view(GoogleAccountAdmin)

View File

@ -60,6 +60,7 @@ class FolderUpdate(BaseModel):
tags: list[str] | None = None
check_interval: int | None = None
active: bool | None = None
exclude_folder_ids: list[str] | None = None
class FolderResponse(BaseModel):
@ -73,6 +74,7 @@ class FolderResponse(BaseModel):
check_interval: int
last_sync_at: str | None
active: bool
exclude_folder_ids: list[str]
class AccountResponse(BaseModel):
@ -370,6 +372,7 @@ def list_accounts(
folder.last_sync_at.isoformat() if folder.last_sync_at else None
),
active=cast(bool, folder.active),
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
)
for folder in account.folders
],
@ -531,6 +534,7 @@ def add_folder(
check_interval=cast(int, new_folder.check_interval),
last_sync_at=None,
active=cast(bool, new_folder.active),
exclude_folder_ids=[],
)
@ -567,6 +571,8 @@ def update_folder(
folder.check_interval = updates.check_interval
if updates.active is not None:
folder.active = updates.active
if updates.exclude_folder_ids is not None:
folder.exclude_folder_ids = updates.exclude_folder_ids
db.commit()
db.refresh(folder)
@ -584,6 +590,7 @@ def update_folder(
folder.last_sync_at.isoformat() if folder.last_sync_at else None
),
active=cast(bool, folder.active),
exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [],
)

View File

@ -350,6 +350,9 @@ class GoogleFolder(Base):
# File type filters (empty = all text documents)
mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}")
# Excluded subfolder IDs (skip these when syncing recursively)
exclude_folder_ids = Column(ARRAY(Text), nullable=False, server_default="{}")
# Tags to apply to all documents from this folder
tags = Column(ARRAY(Text), nullable=False, server_default="{}")

View File

@ -130,9 +130,19 @@ class GoogleDriveClient:
recursive: bool = True,
since: datetime | None = None,
page_size: int = 100,
exclude_folder_ids: set[str] | None = None,
) -> Generator[dict, None, None]:
"""List all supported files in a folder with pagination."""
"""List all supported files in a folder with pagination.
Args:
folder_id: The Google Drive folder ID to list
recursive: Whether to recurse into subfolders
since: Only return files modified after this time
page_size: Number of files per API page
exclude_folder_ids: Set of folder IDs to skip during recursive traversal
"""
service = self._get_service()
exclude_folder_ids = exclude_folder_ids or set()
# Build query for supported file types
all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES
@ -167,11 +177,16 @@ class GoogleDriveClient:
for file in response.get("files", []):
if file["mimeType"] == "application/vnd.google-apps.folder":
if recursive:
if recursive and file["id"] not in exclude_folder_ids:
# Recursively list files in subfolder
yield from self.list_files_in_folder(
file["id"], recursive=True, since=since
file["id"],
recursive=True,
since=since,
exclude_folder_ids=exclude_folder_ids,
)
elif file["id"] in exclude_folder_ids:
logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})")
else:
yield file

View File

@ -251,10 +251,16 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An
# It's a folder - list and sync all files inside
folder_path = client.get_folder_path(google_id)
# Get excluded folder IDs
exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or [])
if exclude_ids:
logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync")
for file_meta in client.list_files_in_folder(
google_id,
recursive=cast(bool, folder.recursive),
since=since,
exclude_folder_ids=exclude_ids,
):
try:
file_data = client.fetch_file(file_meta, folder_path)

View File

@ -5,10 +5,27 @@ import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock, MagicMock, patch
# Mock the mcp module and all its submodules before importing anything that uses it
# Mock FastMCP - this creates a decorator factory that passes through the function unchanged
class MockFastMCP:
def __init__(self, name):
self.name = name
def tool(self):
def decorator(func):
return func
return decorator
# Mock the fastmcp module before importing anything that uses it
_mock_fastmcp = MagicMock()
_mock_fastmcp.FastMCP = MockFastMCP
sys.modules["fastmcp"] = _mock_fastmcp
# Mock the mcp module and all its submodules
_mock_mcp = MagicMock()
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
_mock_mcp.tool = lambda: lambda f: f
sys.modules["mcp"] = _mock_mcp
sys.modules["mcp.types"] = MagicMock()
sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.auth"] = MagicMock()
sys.modules["mcp.server.auth.handlers"] = MagicMock()
@ -21,7 +38,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
# Also mock the memory.api.MCP.base module to avoid MCP imports
_mock_base = MagicMock()
_mock_base.mcp = MagicMock()
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
_mock_base.mcp.tool = lambda: lambda f: f
sys.modules["memory.api.MCP.base"] = _mock_base
from memory.common.db.models import GithubItem
@ -189,7 +206,7 @@ def sample_issues(db_session):
@pytest.mark.asyncio
async def test_list_github_issues_no_filters(db_session, sample_issues):
"""Test listing all issues without filters."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues()
@ -203,7 +220,7 @@ async def test_list_github_issues_no_filters(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
"""Test filtering by repository."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(repo="owner/repo1")
@ -215,7 +232,7 @@ async def test_list_github_issues_filter_by_repo(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
"""Test filtering by assignee."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(assignee="alice")
@ -227,7 +244,7 @@ async def test_list_github_issues_filter_by_assignee(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_author(db_session, sample_issues):
"""Test filtering by author."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(author="alice")
@ -239,7 +256,7 @@ async def test_list_github_issues_filter_by_author(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_state(db_session, sample_issues):
"""Test filtering by state."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
open_results = await list_github_issues(state="open")
@ -259,7 +276,7 @@ async def test_list_github_issues_filter_by_state(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
"""Test filtering by kind (issue vs PR)."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
issues = await list_github_issues(kind="issue")
@ -275,7 +292,7 @@ async def test_list_github_issues_filter_by_kind(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
"""Test filtering by labels."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(labels=["bug"])
@ -287,7 +304,7 @@ async def test_list_github_issues_filter_by_labels(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_project_status(db_session, sample_issues):
"""Test filtering by project status."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(project_status="In Progress")
@ -300,7 +317,7 @@ async def test_list_github_issues_filter_by_project_status(db_session, sample_is
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_project_field(db_session, sample_issues):
"""Test filtering by project field (JSONB)."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(
@ -316,7 +333,7 @@ async def test_list_github_issues_filter_by_project_field(db_session, sample_iss
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues):
"""Test filtering by updated_since."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
now = datetime.now(timezone.utc)
since = (now - timedelta(days=2)).isoformat()
@ -332,7 +349,7 @@ async def test_list_github_issues_filter_by_updated_since(db_session, sample_iss
@pytest.mark.asyncio
async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues):
"""Test filtering by updated_before (stale issues)."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
now = datetime.now(timezone.utc)
before = (now - timedelta(days=30)).isoformat()
@ -348,7 +365,7 @@ async def test_list_github_issues_filter_by_updated_before(db_session, sample_is
@pytest.mark.asyncio
async def test_list_github_issues_order_by_created(db_session, sample_issues):
"""Test ordering by created date."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(order_by="created")
@ -360,7 +377,7 @@ async def test_list_github_issues_order_by_created(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_limit(db_session, sample_issues):
"""Test limiting results."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(limit=2)
@ -371,7 +388,7 @@ async def test_list_github_issues_limit(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
"""Test that limit is capped at 200."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
# Request 500 but should be capped at 200
@ -384,7 +401,7 @@ async def test_list_github_issues_limit_max_enforced(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_combined_filters(db_session, sample_issues):
"""Test combining multiple filters."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(
@ -402,7 +419,7 @@ async def test_list_github_issues_combined_filters(db_session, sample_issues):
@pytest.mark.asyncio
async def test_list_github_issues_url_construction(db_session, sample_issues):
"""Test that URLs are correctly constructed."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(kind="issue", limit=1)
@ -425,7 +442,7 @@ async def test_list_github_issues_url_construction(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_issue_details_found(db_session, sample_issues):
"""Test getting details for an existing issue."""
from memory.api.MCP.github import github_issue_details
from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=1)
@ -440,7 +457,7 @@ async def test_github_issue_details_found(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_issue_details_not_found(db_session, sample_issues):
"""Test getting details for a non-existent issue."""
from memory.api.MCP.github import github_issue_details
from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"):
@ -450,7 +467,7 @@ async def test_github_issue_details_not_found(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_issue_details_pr(db_session, sample_issues):
"""Test getting details for a PR."""
from memory.api.MCP.github import github_issue_details
from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=50)
@ -468,7 +485,7 @@ async def test_github_issue_details_pr(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_by_client(db_session, sample_issues):
"""Test work summary grouped by client."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -489,7 +506,7 @@ async def test_github_work_summary_by_client(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_by_status(db_session, sample_issues):
"""Test work summary grouped by status."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -506,7 +523,7 @@ async def test_github_work_summary_by_status(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_by_author(db_session, sample_issues):
"""Test work summary grouped by author."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -522,7 +539,7 @@ async def test_github_work_summary_by_author(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_by_repo(db_session, sample_issues):
"""Test work summary grouped by repository."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -538,7 +555,7 @@ async def test_github_work_summary_by_repo(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_with_until(db_session, sample_issues):
"""Test work summary with until date."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -553,7 +570,7 @@ async def test_github_work_summary_with_until(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
"""Test work summary filtered by repository."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -571,7 +588,7 @@ async def test_github_work_summary_with_repo_filter(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
"""Test work summary with invalid group_by value."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -584,7 +601,7 @@ async def test_github_work_summary_invalid_group_by(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_work_summary_includes_sample_issues(db_session, sample_issues):
"""Test that work summary includes sample issues."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -610,7 +627,7 @@ async def test_github_work_summary_includes_sample_issues(db_session, sample_iss
@pytest.mark.asyncio
async def test_github_repo_overview_basic(db_session, sample_issues):
"""Test basic repo overview."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -625,7 +642,7 @@ async def test_github_repo_overview_basic(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_counts(db_session, sample_issues):
"""Test repo overview counts are correct."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -638,7 +655,7 @@ async def test_github_repo_overview_counts(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
"""Test repo overview includes status breakdown."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -651,7 +668,7 @@ async def test_github_repo_overview_status_breakdown(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_top_assignees(db_session, sample_issues):
"""Test repo overview includes top assignees."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -663,7 +680,7 @@ async def test_github_repo_overview_top_assignees(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_labels(db_session, sample_issues):
"""Test repo overview includes labels."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -676,7 +693,7 @@ async def test_github_repo_overview_labels(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_last_updated(db_session, sample_issues):
"""Test repo overview includes last updated timestamp."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="owner/repo1")
@ -688,7 +705,7 @@ async def test_github_repo_overview_last_updated(db_session, sample_issues):
@pytest.mark.asyncio
async def test_github_repo_overview_empty_repo(db_session):
"""Test repo overview for a repo with no issues."""
from memory.api.MCP.github import github_repo_overview
from memory.api.MCP.servers.github import github_repo_overview
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_repo_overview(repo="nonexistent/repo")
@ -704,7 +721,7 @@ async def test_github_repo_overview_empty_repo(db_session):
@pytest.mark.asyncio
async def test_search_github_issues_basic(db_session, sample_issues):
"""Test basic search functionality."""
from memory.api.MCP.github import search_github_issues
from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id
@ -723,7 +740,7 @@ async def test_search_github_issues_basic(db_session, sample_issues):
@pytest.mark.asyncio
async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
"""Test search with repository filter."""
from memory.api.MCP.github import search_github_issues
from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id
@ -745,7 +762,7 @@ async def test_search_github_issues_with_repo_filter(db_session, sample_issues):
@pytest.mark.asyncio
async def test_search_github_issues_with_state_filter(db_session, sample_issues):
"""Test search with state filter."""
from memory.api.MCP.github import search_github_issues
from memory.api.MCP.servers.github import search_github_issues
mock_search_result = Mock()
mock_search_result.id = sample_issues[0].id
@ -762,7 +779,7 @@ async def test_search_github_issues_with_state_filter(db_session, sample_issues)
@pytest.mark.asyncio
async def test_search_github_issues_limit(db_session, sample_issues):
"""Test search respects limit."""
from memory.api.MCP.github import search_github_issues
from memory.api.MCP.servers.github import search_github_issues
mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])]
@ -779,7 +796,7 @@ async def test_search_github_issues_limit(db_session, sample_issues):
@pytest.mark.asyncio
async def test_search_github_issues_uses_github_modality(db_session, sample_issues):
"""Test that search uses github modality."""
from memory.api.MCP.github import search_github_issues
from memory.api.MCP.servers.github import search_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search:
@ -797,7 +814,7 @@ async def test_search_github_issues_uses_github_modality(db_session, sample_issu
def test_build_github_url_issue():
"""Test URL construction for issues."""
from memory.api.MCP.github import _build_github_url
from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", 123, "issue")
assert url == "https://github.com/owner/repo/issues/123"
@ -805,7 +822,7 @@ def test_build_github_url_issue():
def test_build_github_url_pr():
"""Test URL construction for PRs."""
from memory.api.MCP.github import _build_github_url
from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", 456, "pr")
assert url == "https://github.com/owner/repo/pull/456"
@ -813,7 +830,7 @@ def test_build_github_url_pr():
def test_build_github_url_no_number():
"""Test URL construction without number."""
from memory.api.MCP.github import _build_github_url
from memory.api.MCP.servers.github import _build_github_url
url = _build_github_url("owner/repo", None, "issue")
assert url == "https://github.com/owner/repo"
@ -821,7 +838,7 @@ def test_build_github_url_no_number():
def test_serialize_issue_basic(db_session, sample_issues):
"""Test issue serialization."""
from memory.api.MCP.github import _serialize_issue
from memory.api.MCP.servers.github import _serialize_issue
issue = sample_issues[0]
result = _serialize_issue(issue)
@ -838,7 +855,7 @@ def test_serialize_issue_basic(db_session, sample_issues):
def test_serialize_issue_with_content(db_session, sample_issues):
"""Test issue serialization with content."""
from memory.api.MCP.github import _serialize_issue
from memory.api.MCP.servers.github import _serialize_issue
issue = sample_issues[0]
result = _serialize_issue(issue, include_content=True)
@ -865,7 +882,7 @@ async def test_list_github_issues_ordering(
db_session, sample_issues, order_by, expected_first_number
):
"""Test different ordering options."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(order_by=order_by)
@ -880,7 +897,7 @@ async def test_list_github_issues_ordering(
)
async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by):
"""Test all valid group_by options."""
from memory.api.MCP.github import github_work_summary
from memory.api.MCP.servers.github import github_work_summary
now = datetime.now(timezone.utc)
since = (now - timedelta(days=30)).isoformat()
@ -906,7 +923,7 @@ async def test_list_github_issues_label_filtering(
db_session, sample_issues, labels, expected_count
):
"""Test various label filtering scenarios."""
from memory.api.MCP.github import list_github_issues
from memory.api.MCP.servers.github import list_github_issues
with patch("memory.api.MCP.github.make_session", return_value=db_session):
results = await list_github_issues(labels=labels)
@ -1014,7 +1031,7 @@ def sample_pr_with_data(db_session):
@pytest.mark.asyncio
async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data):
"""Test that github_issue_details includes PR data for PRs."""
from memory.api.MCP.github import github_issue_details
from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=999)
@ -1034,7 +1051,7 @@ async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_
@pytest.mark.asyncio
async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues):
"""Test that github_issue_details does not include pr_data for issues."""
from memory.api.MCP.github import github_issue_details
from memory.api.MCP.servers.github import github_issue_details
with patch("memory.api.MCP.github.make_session", return_value=db_session):
result = await github_issue_details(repo="owner/repo1", number=1)
@ -1045,7 +1062,7 @@ async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_iss
def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
"""Test that _serialize_issue includes pr_data when include_content=True."""
from memory.api.MCP.github import _serialize_issue
from memory.api.MCP.servers.github import _serialize_issue
result = _serialize_issue(sample_pr_with_data, include_content=True)
@ -1056,7 +1073,7 @@ def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data):
def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data):
"""Test that _serialize_issue excludes pr_data when include_content=False."""
from memory.api.MCP.github import _serialize_issue
from memory.api.MCP.servers.github import _serialize_issue
result = _serialize_issue(sample_pr_with_data, include_content=False)

View File

@ -4,10 +4,27 @@ import sys
import pytest
from unittest.mock import AsyncMock, Mock, MagicMock, patch
# Mock the mcp module and all its submodules before importing anything that uses it
# Mock FastMCP - this creates a decorator factory that passes through the function unchanged
class MockFastMCP:
def __init__(self, name):
self.name = name
def tool(self):
def decorator(func):
return func
return decorator
# Mock the fastmcp module before importing anything that uses it
_mock_fastmcp = MagicMock()
_mock_fastmcp.FastMCP = MockFastMCP
sys.modules["fastmcp"] = _mock_fastmcp
# Mock the mcp module and all its submodules
_mock_mcp = MagicMock()
_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
_mock_mcp.tool = lambda: lambda f: f
sys.modules["mcp"] = _mock_mcp
sys.modules["mcp.types"] = MagicMock()
sys.modules["mcp.server"] = MagicMock()
sys.modules["mcp.server.auth"] = MagicMock()
sys.modules["mcp.server.auth.handlers"] = MagicMock()
@ -20,7 +37,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock()
# Also mock the memory.api.MCP.base module to avoid MCP imports
_mock_base = MagicMock()
_mock_base.mcp = MagicMock()
_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator
_mock_base.mcp.tool = lambda: lambda f: f
sys.modules["memory.api.MCP.base"] = _mock_base
from memory.common.db.models import Person
@ -97,7 +114,7 @@ def sample_people(db_session):
@pytest.mark.asyncio
async def test_add_person_success(db_session):
"""Test adding a new person."""
from memory.api.MCP.people import add_person
from memory.api.MCP.servers.people import add_person
mock_task = Mock()
mock_task.id = "task-123"
@ -128,7 +145,7 @@ async def test_add_person_success(db_session):
@pytest.mark.asyncio
async def test_add_person_already_exists(db_session, sample_people):
"""Test adding a person that already exists."""
from memory.api.MCP.people import add_person
from memory.api.MCP.servers.people import add_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="already exists"):
@ -141,7 +158,7 @@ async def test_add_person_already_exists(db_session, sample_people):
@pytest.mark.asyncio
async def test_add_person_minimal(db_session):
"""Test adding a person with minimal data."""
from memory.api.MCP.people import add_person
from memory.api.MCP.servers.people import add_person
mock_task = Mock()
mock_task.id = "task-456"
@ -166,7 +183,7 @@ async def test_add_person_minimal(db_session):
@pytest.mark.asyncio
async def test_update_person_info_success(db_session, sample_people):
"""Test updating a person's info."""
from memory.api.MCP.people import update_person_info
from memory.api.MCP.servers.people import update_person_info
mock_task = Mock()
mock_task.id = "task-789"
@ -188,7 +205,7 @@ async def test_update_person_info_success(db_session, sample_people):
@pytest.mark.asyncio
async def test_update_person_info_not_found(db_session, sample_people):
"""Test updating a person that doesn't exist."""
from memory.api.MCP.people import update_person_info
from memory.api.MCP.servers.people import update_person_info
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"):
@ -201,7 +218,7 @@ async def test_update_person_info_not_found(db_session, sample_people):
@pytest.mark.asyncio
async def test_update_person_info_with_merge_params(db_session, sample_people):
"""Test that update passes all merge parameters."""
from memory.api.MCP.people import update_person_info
from memory.api.MCP.servers.people import update_person_info
mock_task = Mock()
mock_task.id = "task-merge"
@ -234,7 +251,7 @@ async def test_update_person_info_with_merge_params(db_session, sample_people):
@pytest.mark.asyncio
async def test_get_person_found(db_session, sample_people):
"""Test getting a person that exists."""
from memory.api.MCP.people import get_person
from memory.api.MCP.servers.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="alice_chen")
@ -251,7 +268,7 @@ async def test_get_person_found(db_session, sample_people):
@pytest.mark.asyncio
async def test_get_person_not_found(db_session, sample_people):
"""Test getting a person that doesn't exist."""
from memory.api.MCP.people import get_person
from memory.api.MCP.servers.people import get_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await get_person(identifier="nonexistent_person")
@ -267,7 +284,7 @@ async def test_get_person_not_found(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_no_filters(db_session, sample_people):
"""Test listing all people without filters."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people()
@ -282,7 +299,7 @@ async def test_list_people_no_filters(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_filter_by_tags(db_session, sample_people):
"""Test filtering by tags."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"])
@ -294,7 +311,7 @@ async def test_list_people_filter_by_tags(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_filter_by_search(db_session, sample_people):
"""Test filtering by search term."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="alice")
@ -306,7 +323,7 @@ async def test_list_people_filter_by_search(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_search_in_notes(db_session, sample_people):
"""Test that search works on notes content."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search="climbing")
@ -318,7 +335,7 @@ async def test_list_people_search_in_notes(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_limit(db_session, sample_people):
"""Test limiting results."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(limit=1)
@ -329,7 +346,7 @@ async def test_list_people_limit(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_limit_max_enforced(db_session, sample_people):
"""Test that limit is capped at 200."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
# Request 500 but should be capped at 200
@ -342,7 +359,7 @@ async def test_list_people_limit_max_enforced(db_session, sample_people):
@pytest.mark.asyncio
async def test_list_people_combined_filters(db_session, sample_people):
"""Test combining tag and search filters."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=["work"], search="chen")
@ -359,7 +376,7 @@ async def test_list_people_combined_filters(db_session, sample_people):
@pytest.mark.asyncio
async def test_delete_person_success(db_session, sample_people):
"""Test deleting a person."""
from memory.api.MCP.people import delete_person
from memory.api.MCP.servers.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
result = await delete_person(identifier="bob_smith")
@ -376,7 +393,7 @@ async def test_delete_person_success(db_session, sample_people):
@pytest.mark.asyncio
async def test_delete_person_not_found(db_session, sample_people):
"""Test deleting a person that doesn't exist."""
from memory.api.MCP.people import delete_person
from memory.api.MCP.servers.people import delete_person
with patch("memory.api.MCP.people.make_session", return_value=db_session):
with pytest.raises(ValueError, match="not found"):
@ -390,7 +407,7 @@ async def test_delete_person_not_found(db_session, sample_people):
def test_person_to_dict(sample_people):
"""Test the _person_to_dict helper function."""
from memory.api.MCP.people import _person_to_dict
from memory.api.MCP.servers.people import _person_to_dict
person = sample_people[0]
result = _person_to_dict(person)
@ -406,7 +423,7 @@ def test_person_to_dict(sample_people):
def test_person_to_dict_empty_fields(db_session):
"""Test _person_to_dict with empty optional fields."""
from memory.api.MCP.people import _person_to_dict
from memory.api.MCP.servers.people import _person_to_dict
person = Person(
identifier="empty_person",
@ -448,7 +465,7 @@ def test_person_to_dict_empty_fields(db_session):
)
async def test_list_people_various_tags(db_session, sample_people, tag, expected_count):
"""Test filtering by various tags."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(tags=[tag])
@ -473,7 +490,7 @@ async def test_list_people_various_searches(
db_session, sample_people, search_term, expected_identifiers
):
"""Test search with various terms."""
from memory.api.MCP.people import list_people
from memory.api.MCP.servers.people import list_people
with patch("memory.api.MCP.people.make_session", return_value=db_session):
results = await list_people(search=search_term)

View File

@ -69,8 +69,12 @@ class TestSearchResult:
result = SearchResult.from_source_item(source, chunks)
assert result.search_score == 0.9
def test_search_score_multiple_chunks_uses_mean(self):
"""Multiple chunks should use mean of relevance scores, not sum."""
def test_search_score_multiple_chunks_uses_max(self):
"""Multiple chunks should use max of relevance scores.
Using max finds documents with at least one highly relevant section,
which is better for 'half-remembered' searches where users recall one detail.
"""
source = self._make_source_item()
chunks = [
self._make_chunk(0.9),
@ -79,8 +83,8 @@ class TestSearchResult:
]
result = SearchResult.from_source_item(source, chunks)
# Mean of 0.9, 0.7, 0.8 = 0.8
assert result.search_score == pytest.approx(0.8)
# Max of 0.9, 0.7, 0.8 = 0.9
assert result.search_score == pytest.approx(0.9)
def test_search_score_empty_chunks(self):
"""Empty chunk list should result in None or 0 score."""