From 9cf71c9336b69267d84aa15e2243dadb2826e97b Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 29 Dec 2025 17:03:08 +0100 Subject: [PATCH] exclude stuff --- .../20251229_150000_add_exclude_folder_ids.py | 35 +++ frontend/src/App.css | 94 ++++++++ frontend/src/components/sources/Sources.tsx | 220 +++++++++++++++++- frontend/src/hooks/useSources.ts | 2 + src/memory/api/admin.py | 42 ++++ src/memory/api/google_drive.py | 7 + src/memory/common/db/models/sources.py | 3 + src/memory/parsers/google_drive.py | 21 +- src/memory/workers/tasks/google_drive.py | 6 + tests/memory/api/MCP/test_github.py | 125 +++++----- tests/memory/api/MCP/test_people.py | 65 ++++-- tests/memory/api/search/test_types.py | 12 +- 12 files changed, 546 insertions(+), 86 deletions(-) create mode 100644 db/migrations/versions/20251229_150000_add_exclude_folder_ids.py diff --git a/db/migrations/versions/20251229_150000_add_exclude_folder_ids.py b/db/migrations/versions/20251229_150000_add_exclude_folder_ids.py new file mode 100644 index 0000000..895510c --- /dev/null +++ b/db/migrations/versions/20251229_150000_add_exclude_folder_ids.py @@ -0,0 +1,35 @@ +"""Add exclude_folder_ids to google_folders + +Revision ID: add_exclude_folder_ids +Revises: 20251229_120000_add_google_drive +Create Date: 2025-12-29 15:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "add_exclude_folder_ids" +down_revision: Union[str, None] = "f2a3b4c5d6e7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "google_folders", + sa.Column( + "exclude_folder_ids", + postgresql.ARRAY(sa.Text()), + nullable=False, + server_default="{}", + ), + ) + + +def downgrade() -> None: + op.drop_column("google_folders", "exclude_folder_ids") diff --git a/frontend/src/App.css b/frontend/src/App.css index 64f0dc0..7359d3e 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -1935,4 +1935,98 @@ a.folder-item-name:hover { .add-btn.secondary:hover { background: #edf2f7; +} + +/* === Exclusion Browser === */ + +.exclusion-browser { + display: flex; + flex-direction: column; + min-height: 400px; + max-height: 60vh; +} + +.excluded-list { + background: #fffbeb; + border: 1px solid #f59e0b; + border-radius: 6px; + padding: 0.75rem; + margin-bottom: 0.75rem; +} + +.excluded-list h5 { + font-size: 0.85rem; + color: #92400e; + margin-bottom: 0.5rem; +} + +.excluded-items { + display: flex; + flex-wrap: wrap; + gap: 0.5rem; +} + +.excluded-item { + display: flex; + align-items: center; + gap: 0.25rem; + background: #fef3c7; + padding: 0.25rem 0.5rem; + border-radius: 4px; + font-size: 0.85rem; +} + +.excluded-name { + color: #92400e; +} + +.remove-exclusion-btn { + background: none; + border: none; + color: #b45309; + cursor: pointer; + padding: 0 0.25rem; + font-size: 0.75rem; + line-height: 1; +} + +.remove-exclusion-btn:hover { + color: #dc2626; +} + +.folder-item.excluded { + background: #fffbeb; + border-color: #f59e0b; +} + +.exclusion-badge { + background: #f59e0b; + color: white; + padding: 0.125rem 0.375rem; + border-radius: 3px; + font-size: 0.7rem; + margin-left: auto; + margin-right: 0.5rem; +} + +.exclusions-btn { + background: #fffbeb; + color: #b45309; + border: 1px solid #f59e0b; + padding: 0.25rem 0.5rem; + border-radius: 4px; + font-size: 0.75rem; + cursor: pointer; + transition: all 0.15s ease; +} + +.exclusions-btn:hover { + background: #fef3c7; + color: #92400e; +} + +.setting-badge.warning { + background: #fffbeb; + color: #b45309; + border: 1px solid #f59e0b; } \ No newline at end of file diff --git a/frontend/src/components/sources/Sources.tsx b/frontend/src/components/sources/Sources.tsx index 78c0230..b56521e 100644 --- a/frontend/src/components/sources/Sources.tsx +++ b/frontend/src/components/sources/Sources.tsx @@ -1,6 +1,6 @@ import { useState, useEffect, useCallback } from 'react' import { Link } from 'react-router-dom' -import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources' +import { useSources, EmailAccount, ArticleFeed, GithubAccount, GoogleAccount, GoogleFolder, GoogleOAuthConfig, DriveItem, BrowseResponse, GoogleFolderCreate } from '@/hooks/useSources' import { SourceCard, Modal, @@ -956,6 +956,7 @@ const GoogleDrivePanel = () => { const [error, setError] = useState(null) const [addingFolderTo, setAddingFolderTo] = useState(null) const [browsingFoldersFor, setBrowsingFoldersFor] = useState(null) + const [managingExclusionsFor, setManagingExclusionsFor] = useState<{ accountId: number; folder: GoogleFolder } | null>(null) const [uploadingConfig, setUploadingConfig] = useState(false) const loadData = useCallback(async () => { @@ -1044,6 +1045,12 @@ const GoogleDrivePanel = () => { loadData() } + const handleUpdateExclusions = async (accountId: number, folderId: number, excludeIds: string[]) => { + await updateGoogleFolder(accountId, folderId, { exclude_folder_ids: excludeIds }) + setManagingExclusionsFor(null) + loadData() + } + const handleDeleteFolder = async (accountId: number, folderId: number) => { await deleteGoogleFolder(accountId, folderId) loadData() @@ -1188,6 +1195,11 @@ const GoogleDrivePanel = () => {
{folder.recursive && Recursive} {folder.include_shared && Shared} + {folder.exclude_folder_ids.length > 0 && ( + + {folder.exclude_folder_ids.length} excluded + + )}
{ disabled={!folder.active || !account.active} label="Sync" /> + {folder.recursive && ( + + )}
) } @@ -1422,6 +1451,195 @@ const FolderBrowser = ({ accountId, onSelect, onCancel }: FolderBrowserProps) => ) } +// === Exclusion Browser === + +interface ExclusionBrowserProps { + accountId: number + folder: GoogleFolder + onSave: (excludeIds: string[]) => void + onCancel: () => void +} + +interface ExcludedFolder { + id: string + name: string + path: string +} + +const ExclusionBrowser = ({ accountId, folder, onSave, onCancel }: ExclusionBrowserProps) => { + const { browseGoogleDrive } = useSources() + const [path, setPath] = useState([{ id: folder.folder_id, name: folder.folder_name }]) + const [items, setItems] = useState([]) + const [excluded, setExcluded] = useState>(() => { + // Initialize with current exclusions (we don't have names, so use ID as name) + const map = new Map() + for (const id of folder.exclude_folder_ids) { + map.set(id, { id, name: id, path: '(previously excluded)' }) + } + return map + }) + const [loading, setLoading] = useState(true) + const [error, setError] = useState(null) + + const currentFolderId = path[path.length - 1].id + const currentPath = path.map(p => p.name).join(' > ') + + const loadFolder = useCallback(async (folderId: string) => { + setLoading(true) + setError(null) + try { + const response = await browseGoogleDrive(accountId, folderId) + // Only show folders, not files + setItems(response.items.filter(item => item.is_folder)) + } catch (e) { + setError(e instanceof Error ? e.message : 'Failed to load folder') + } finally { + setLoading(false) + } + }, [accountId, browseGoogleDrive]) + + useEffect(() => { + loadFolder(currentFolderId) + }, [currentFolderId, loadFolder]) + + const navigateToFolder = (item: DriveItem) => { + setPath([...path, { id: item.id, name: item.name }]) + } + + const navigateToPathIndex = (index: number) => { + setPath(path.slice(0, index + 1)) + } + + const toggleExclude = (item: DriveItem) => { + const newExcluded = new Map(excluded) + if (newExcluded.has(item.id)) { + newExcluded.delete(item.id) + } else { + newExcluded.set(item.id, { + id: item.id, + name: item.name, + path: currentPath + ' > ' + item.name, + }) + } + setExcluded(newExcluded) + } + + const removeExclusion = (id: string) => { + const newExcluded = new Map(excluded) + newExcluded.delete(id) + setExcluded(newExcluded) + } + + const handleSave = () => { + onSave(Array.from(excluded.keys())) + } + + return ( + +
+ {/* Current exclusions */} + {excluded.size > 0 && ( +
+
Excluded Folders ({excluded.size})
+
+ {Array.from(excluded.values()).map(item => ( +
+ + 📁 {item.name} + + +
+ ))} +
+
+ )} + + {/* Breadcrumb */} +
+ {path.map((item, index) => ( + + {index > 0 && >} + + + ))} +
+ + {/* Content */} + {error &&
{error}
} + + {loading ? ( +
Loading...
+ ) : items.length === 0 ? ( +
No subfolders in this folder
+ ) : ( +
+ {items.map(item => ( +
+ + 📁 + e.stopPropagation()} + > + {item.name} + + {excluded.has(item.id) && ( + Excluded + )} + +
+ ))} +
+ )} + + {/* Footer */} +
+ + {excluded.size} folder{excluded.size !== 1 ? 's' : ''} excluded + +
+ + +
+
+
+
+ ) +} + interface GoogleFolderFormProps { accountId: number onSubmit: (data: any) => Promise diff --git a/frontend/src/hooks/useSources.ts b/frontend/src/hooks/useSources.ts index aab3ee2..90f14c5 100644 --- a/frontend/src/hooks/useSources.ts +++ b/frontend/src/hooks/useSources.ts @@ -177,6 +177,7 @@ export interface GoogleFolder { check_interval: number last_sync_at: string | null active: boolean + exclude_folder_ids: string[] } export interface GoogleAccount { @@ -205,6 +206,7 @@ export interface GoogleFolderUpdate { tags?: string[] check_interval?: number active?: boolean + exclude_folder_ids?: string[] } // Types for Google Drive browsing diff --git a/src/memory/api/admin.py b/src/memory/api/admin.py index af15ff1..fd46a69 100644 --- a/src/memory/api/admin.py +++ b/src/memory/api/admin.py @@ -29,6 +29,9 @@ from memory.common.db.models import ( ScheduledLLMCall, SourceItem, User, + GoogleDoc, + GoogleFolder, + GoogleAccount, ) from memory.common.db.models.discord import DiscordChannel, DiscordServer, DiscordUser @@ -394,6 +397,42 @@ class GithubItemAdmin(ModelView, model=GithubItem): column_sortable_list = ["github_updated_at", "created_at"] +class GoogleDocAdmin(ModelView, model=GoogleDoc): + column_list = source_columns( + GoogleDoc, + "title", + "folder_path", + "owner", + "last_modified_by", + "word_count", + "content_hash", + ) + column_searchable_list = ["title", "folder_path", "owner", "last_modified_by", "id"] + column_sortable_list = ["google_modified_at", "created_at"] + + +class GoogleFolderAdmin(ModelView, model=GoogleFolder): + column_list = source_columns( + GoogleFolder, "folder_name", "folder_path", "account", "active" + ) + column_searchable_list = ["folder_name", "folder_path", "id"] + column_sortable_list = ["last_sync_at", "created_at"] + + +class GoogleAccountAdmin(ModelView, model=GoogleAccount): + column_list = [ + "id", + "name", + "email", + "active", + "last_sync_at", + "created_at", + "updated_at", + ] + column_searchable_list = ["name", "email", "id"] + column_sortable_list = ["last_sync_at", "created_at"] + + def setup_admin(admin: Admin): """Add all admin views to the admin instance with OAuth protection.""" admin.add_view(SourceItemAdmin) @@ -421,3 +460,6 @@ def setup_admin(admin: Admin): admin.add_view(GithubAccountAdmin) admin.add_view(GithubRepoAdmin) admin.add_view(GithubItemAdmin) + admin.add_view(GoogleDocAdmin) + admin.add_view(GoogleFolderAdmin) + admin.add_view(GoogleAccountAdmin) diff --git a/src/memory/api/google_drive.py b/src/memory/api/google_drive.py index b154f1a..57ad533 100644 --- a/src/memory/api/google_drive.py +++ b/src/memory/api/google_drive.py @@ -60,6 +60,7 @@ class FolderUpdate(BaseModel): tags: list[str] | None = None check_interval: int | None = None active: bool | None = None + exclude_folder_ids: list[str] | None = None class FolderResponse(BaseModel): @@ -73,6 +74,7 @@ class FolderResponse(BaseModel): check_interval: int last_sync_at: str | None active: bool + exclude_folder_ids: list[str] class AccountResponse(BaseModel): @@ -370,6 +372,7 @@ def list_accounts( folder.last_sync_at.isoformat() if folder.last_sync_at else None ), active=cast(bool, folder.active), + exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [], ) for folder in account.folders ], @@ -531,6 +534,7 @@ def add_folder( check_interval=cast(int, new_folder.check_interval), last_sync_at=None, active=cast(bool, new_folder.active), + exclude_folder_ids=[], ) @@ -567,6 +571,8 @@ def update_folder( folder.check_interval = updates.check_interval if updates.active is not None: folder.active = updates.active + if updates.exclude_folder_ids is not None: + folder.exclude_folder_ids = updates.exclude_folder_ids db.commit() db.refresh(folder) @@ -584,6 +590,7 @@ def update_folder( folder.last_sync_at.isoformat() if folder.last_sync_at else None ), active=cast(bool, folder.active), + exclude_folder_ids=cast(list[str], folder.exclude_folder_ids) or [], ) diff --git a/src/memory/common/db/models/sources.py b/src/memory/common/db/models/sources.py index f95c5cf..096179b 100644 --- a/src/memory/common/db/models/sources.py +++ b/src/memory/common/db/models/sources.py @@ -350,6 +350,9 @@ class GoogleFolder(Base): # File type filters (empty = all text documents) mime_type_filter = Column(ARRAY(Text), nullable=False, server_default="{}") + # Excluded subfolder IDs (skip these when syncing recursively) + exclude_folder_ids = Column(ARRAY(Text), nullable=False, server_default="{}") + # Tags to apply to all documents from this folder tags = Column(ARRAY(Text), nullable=False, server_default="{}") diff --git a/src/memory/parsers/google_drive.py b/src/memory/parsers/google_drive.py index f05ea1d..49c9c75 100644 --- a/src/memory/parsers/google_drive.py +++ b/src/memory/parsers/google_drive.py @@ -130,9 +130,19 @@ class GoogleDriveClient: recursive: bool = True, since: datetime | None = None, page_size: int = 100, + exclude_folder_ids: set[str] | None = None, ) -> Generator[dict, None, None]: - """List all supported files in a folder with pagination.""" + """List all supported files in a folder with pagination. + + Args: + folder_id: The Google Drive folder ID to list + recursive: Whether to recurse into subfolders + since: Only return files modified after this time + page_size: Number of files per API page + exclude_folder_ids: Set of folder IDs to skip during recursive traversal + """ service = self._get_service() + exclude_folder_ids = exclude_folder_ids or set() # Build query for supported file types all_mimes = SUPPORTED_GOOGLE_MIMES | SUPPORTED_FILE_MIMES @@ -167,11 +177,16 @@ class GoogleDriveClient: for file in response.get("files", []): if file["mimeType"] == "application/vnd.google-apps.folder": - if recursive: + if recursive and file["id"] not in exclude_folder_ids: # Recursively list files in subfolder yield from self.list_files_in_folder( - file["id"], recursive=True, since=since + file["id"], + recursive=True, + since=since, + exclude_folder_ids=exclude_folder_ids, ) + elif file["id"] in exclude_folder_ids: + logger.info(f"Skipping excluded folder: {file['name']} ({file['id']})") else: yield file diff --git a/src/memory/workers/tasks/google_drive.py b/src/memory/workers/tasks/google_drive.py index 0fb30b0..1f86b78 100644 --- a/src/memory/workers/tasks/google_drive.py +++ b/src/memory/workers/tasks/google_drive.py @@ -251,10 +251,16 @@ def sync_google_folder(folder_id: int, force_full: bool = False) -> dict[str, An # It's a folder - list and sync all files inside folder_path = client.get_folder_path(google_id) + # Get excluded folder IDs + exclude_ids = set(cast(list[str], folder.exclude_folder_ids) or []) + if exclude_ids: + logger.info(f"Excluding {len(exclude_ids)} folder(s) from sync") + for file_meta in client.list_files_in_folder( google_id, recursive=cast(bool, folder.recursive), since=since, + exclude_folder_ids=exclude_ids, ): try: file_data = client.fetch_file(file_meta, folder_path) diff --git a/tests/memory/api/MCP/test_github.py b/tests/memory/api/MCP/test_github.py index ee42f6b..2a2d3e0 100644 --- a/tests/memory/api/MCP/test_github.py +++ b/tests/memory/api/MCP/test_github.py @@ -5,10 +5,27 @@ import pytest from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, MagicMock, patch -# Mock the mcp module and all its submodules before importing anything that uses it +# Mock FastMCP - this creates a decorator factory that passes through the function unchanged +class MockFastMCP: + def __init__(self, name): + self.name = name + + def tool(self): + def decorator(func): + return func + return decorator + + +# Mock the fastmcp module before importing anything that uses it +_mock_fastmcp = MagicMock() +_mock_fastmcp.FastMCP = MockFastMCP +sys.modules["fastmcp"] = _mock_fastmcp + +# Mock the mcp module and all its submodules _mock_mcp = MagicMock() -_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +_mock_mcp.tool = lambda: lambda f: f sys.modules["mcp"] = _mock_mcp +sys.modules["mcp.types"] = MagicMock() sys.modules["mcp.server"] = MagicMock() sys.modules["mcp.server.auth"] = MagicMock() sys.modules["mcp.server.auth.handlers"] = MagicMock() @@ -21,7 +38,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock() # Also mock the memory.api.MCP.base module to avoid MCP imports _mock_base = MagicMock() _mock_base.mcp = MagicMock() -_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +_mock_base.mcp.tool = lambda: lambda f: f sys.modules["memory.api.MCP.base"] = _mock_base from memory.common.db.models import GithubItem @@ -189,7 +206,7 @@ def sample_issues(db_session): @pytest.mark.asyncio async def test_list_github_issues_no_filters(db_session, sample_issues): """Test listing all issues without filters.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues() @@ -203,7 +220,7 @@ async def test_list_github_issues_no_filters(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_repo(db_session, sample_issues): """Test filtering by repository.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(repo="owner/repo1") @@ -215,7 +232,7 @@ async def test_list_github_issues_filter_by_repo(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_assignee(db_session, sample_issues): """Test filtering by assignee.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(assignee="alice") @@ -227,7 +244,7 @@ async def test_list_github_issues_filter_by_assignee(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_author(db_session, sample_issues): """Test filtering by author.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(author="alice") @@ -239,7 +256,7 @@ async def test_list_github_issues_filter_by_author(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_state(db_session, sample_issues): """Test filtering by state.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): open_results = await list_github_issues(state="open") @@ -259,7 +276,7 @@ async def test_list_github_issues_filter_by_state(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_kind(db_session, sample_issues): """Test filtering by kind (issue vs PR).""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): issues = await list_github_issues(kind="issue") @@ -275,7 +292,7 @@ async def test_list_github_issues_filter_by_kind(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_labels(db_session, sample_issues): """Test filtering by labels.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(labels=["bug"]) @@ -287,7 +304,7 @@ async def test_list_github_issues_filter_by_labels(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_filter_by_project_status(db_session, sample_issues): """Test filtering by project status.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(project_status="In Progress") @@ -300,7 +317,7 @@ async def test_list_github_issues_filter_by_project_status(db_session, sample_is @pytest.mark.asyncio async def test_list_github_issues_filter_by_project_field(db_session, sample_issues): """Test filtering by project field (JSONB).""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues( @@ -316,7 +333,7 @@ async def test_list_github_issues_filter_by_project_field(db_session, sample_iss @pytest.mark.asyncio async def test_list_github_issues_filter_by_updated_since(db_session, sample_issues): """Test filtering by updated_since.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues now = datetime.now(timezone.utc) since = (now - timedelta(days=2)).isoformat() @@ -332,7 +349,7 @@ async def test_list_github_issues_filter_by_updated_since(db_session, sample_iss @pytest.mark.asyncio async def test_list_github_issues_filter_by_updated_before(db_session, sample_issues): """Test filtering by updated_before (stale issues).""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues now = datetime.now(timezone.utc) before = (now - timedelta(days=30)).isoformat() @@ -348,7 +365,7 @@ async def test_list_github_issues_filter_by_updated_before(db_session, sample_is @pytest.mark.asyncio async def test_list_github_issues_order_by_created(db_session, sample_issues): """Test ordering by created date.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(order_by="created") @@ -360,7 +377,7 @@ async def test_list_github_issues_order_by_created(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_limit(db_session, sample_issues): """Test limiting results.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(limit=2) @@ -371,7 +388,7 @@ async def test_list_github_issues_limit(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_limit_max_enforced(db_session, sample_issues): """Test that limit is capped at 200.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): # Request 500 but should be capped at 200 @@ -384,7 +401,7 @@ async def test_list_github_issues_limit_max_enforced(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_combined_filters(db_session, sample_issues): """Test combining multiple filters.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues( @@ -402,7 +419,7 @@ async def test_list_github_issues_combined_filters(db_session, sample_issues): @pytest.mark.asyncio async def test_list_github_issues_url_construction(db_session, sample_issues): """Test that URLs are correctly constructed.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(kind="issue", limit=1) @@ -425,7 +442,7 @@ async def test_list_github_issues_url_construction(db_session, sample_issues): @pytest.mark.asyncio async def test_github_issue_details_found(db_session, sample_issues): """Test getting details for an existing issue.""" - from memory.api.MCP.github import github_issue_details + from memory.api.MCP.servers.github import github_issue_details with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_issue_details(repo="owner/repo1", number=1) @@ -440,7 +457,7 @@ async def test_github_issue_details_found(db_session, sample_issues): @pytest.mark.asyncio async def test_github_issue_details_not_found(db_session, sample_issues): """Test getting details for a non-existent issue.""" - from memory.api.MCP.github import github_issue_details + from memory.api.MCP.servers.github import github_issue_details with patch("memory.api.MCP.github.make_session", return_value=db_session): with pytest.raises(ValueError, match="not found"): @@ -450,7 +467,7 @@ async def test_github_issue_details_not_found(db_session, sample_issues): @pytest.mark.asyncio async def test_github_issue_details_pr(db_session, sample_issues): """Test getting details for a PR.""" - from memory.api.MCP.github import github_issue_details + from memory.api.MCP.servers.github import github_issue_details with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_issue_details(repo="owner/repo1", number=50) @@ -468,7 +485,7 @@ async def test_github_issue_details_pr(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_by_client(db_session, sample_issues): """Test work summary grouped by client.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -489,7 +506,7 @@ async def test_github_work_summary_by_client(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_by_status(db_session, sample_issues): """Test work summary grouped by status.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -506,7 +523,7 @@ async def test_github_work_summary_by_status(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_by_author(db_session, sample_issues): """Test work summary grouped by author.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -522,7 +539,7 @@ async def test_github_work_summary_by_author(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_by_repo(db_session, sample_issues): """Test work summary grouped by repository.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -538,7 +555,7 @@ async def test_github_work_summary_by_repo(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_with_until(db_session, sample_issues): """Test work summary with until date.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -553,7 +570,7 @@ async def test_github_work_summary_with_until(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_with_repo_filter(db_session, sample_issues): """Test work summary filtered by repository.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -571,7 +588,7 @@ async def test_github_work_summary_with_repo_filter(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_invalid_group_by(db_session, sample_issues): """Test work summary with invalid group_by value.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -584,7 +601,7 @@ async def test_github_work_summary_invalid_group_by(db_session, sample_issues): @pytest.mark.asyncio async def test_github_work_summary_includes_sample_issues(db_session, sample_issues): """Test that work summary includes sample issues.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -610,7 +627,7 @@ async def test_github_work_summary_includes_sample_issues(db_session, sample_iss @pytest.mark.asyncio async def test_github_repo_overview_basic(db_session, sample_issues): """Test basic repo overview.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -625,7 +642,7 @@ async def test_github_repo_overview_basic(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_counts(db_session, sample_issues): """Test repo overview counts are correct.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -638,7 +655,7 @@ async def test_github_repo_overview_counts(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_status_breakdown(db_session, sample_issues): """Test repo overview includes status breakdown.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -651,7 +668,7 @@ async def test_github_repo_overview_status_breakdown(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_top_assignees(db_session, sample_issues): """Test repo overview includes top assignees.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -663,7 +680,7 @@ async def test_github_repo_overview_top_assignees(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_labels(db_session, sample_issues): """Test repo overview includes labels.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -676,7 +693,7 @@ async def test_github_repo_overview_labels(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_last_updated(db_session, sample_issues): """Test repo overview includes last updated timestamp.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="owner/repo1") @@ -688,7 +705,7 @@ async def test_github_repo_overview_last_updated(db_session, sample_issues): @pytest.mark.asyncio async def test_github_repo_overview_empty_repo(db_session): """Test repo overview for a repo with no issues.""" - from memory.api.MCP.github import github_repo_overview + from memory.api.MCP.servers.github import github_repo_overview with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_repo_overview(repo="nonexistent/repo") @@ -704,7 +721,7 @@ async def test_github_repo_overview_empty_repo(db_session): @pytest.mark.asyncio async def test_search_github_issues_basic(db_session, sample_issues): """Test basic search functionality.""" - from memory.api.MCP.github import search_github_issues + from memory.api.MCP.servers.github import search_github_issues mock_search_result = Mock() mock_search_result.id = sample_issues[0].id @@ -723,7 +740,7 @@ async def test_search_github_issues_basic(db_session, sample_issues): @pytest.mark.asyncio async def test_search_github_issues_with_repo_filter(db_session, sample_issues): """Test search with repository filter.""" - from memory.api.MCP.github import search_github_issues + from memory.api.MCP.servers.github import search_github_issues mock_search_result = Mock() mock_search_result.id = sample_issues[0].id @@ -745,7 +762,7 @@ async def test_search_github_issues_with_repo_filter(db_session, sample_issues): @pytest.mark.asyncio async def test_search_github_issues_with_state_filter(db_session, sample_issues): """Test search with state filter.""" - from memory.api.MCP.github import search_github_issues + from memory.api.MCP.servers.github import search_github_issues mock_search_result = Mock() mock_search_result.id = sample_issues[0].id @@ -762,7 +779,7 @@ async def test_search_github_issues_with_state_filter(db_session, sample_issues) @pytest.mark.asyncio async def test_search_github_issues_limit(db_session, sample_issues): """Test search respects limit.""" - from memory.api.MCP.github import search_github_issues + from memory.api.MCP.servers.github import search_github_issues mock_results = [Mock(id=issue.id, score=0.9 - i * 0.1) for i, issue in enumerate(sample_issues[:3])] @@ -779,7 +796,7 @@ async def test_search_github_issues_limit(db_session, sample_issues): @pytest.mark.asyncio async def test_search_github_issues_uses_github_modality(db_session, sample_issues): """Test that search uses github modality.""" - from memory.api.MCP.github import search_github_issues + from memory.api.MCP.servers.github import search_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): with patch("memory.api.MCP.github.search", new_callable=AsyncMock) as mock_search: @@ -797,7 +814,7 @@ async def test_search_github_issues_uses_github_modality(db_session, sample_issu def test_build_github_url_issue(): """Test URL construction for issues.""" - from memory.api.MCP.github import _build_github_url + from memory.api.MCP.servers.github import _build_github_url url = _build_github_url("owner/repo", 123, "issue") assert url == "https://github.com/owner/repo/issues/123" @@ -805,7 +822,7 @@ def test_build_github_url_issue(): def test_build_github_url_pr(): """Test URL construction for PRs.""" - from memory.api.MCP.github import _build_github_url + from memory.api.MCP.servers.github import _build_github_url url = _build_github_url("owner/repo", 456, "pr") assert url == "https://github.com/owner/repo/pull/456" @@ -813,7 +830,7 @@ def test_build_github_url_pr(): def test_build_github_url_no_number(): """Test URL construction without number.""" - from memory.api.MCP.github import _build_github_url + from memory.api.MCP.servers.github import _build_github_url url = _build_github_url("owner/repo", None, "issue") assert url == "https://github.com/owner/repo" @@ -821,7 +838,7 @@ def test_build_github_url_no_number(): def test_serialize_issue_basic(db_session, sample_issues): """Test issue serialization.""" - from memory.api.MCP.github import _serialize_issue + from memory.api.MCP.servers.github import _serialize_issue issue = sample_issues[0] result = _serialize_issue(issue) @@ -838,7 +855,7 @@ def test_serialize_issue_basic(db_session, sample_issues): def test_serialize_issue_with_content(db_session, sample_issues): """Test issue serialization with content.""" - from memory.api.MCP.github import _serialize_issue + from memory.api.MCP.servers.github import _serialize_issue issue = sample_issues[0] result = _serialize_issue(issue, include_content=True) @@ -865,7 +882,7 @@ async def test_list_github_issues_ordering( db_session, sample_issues, order_by, expected_first_number ): """Test different ordering options.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(order_by=order_by) @@ -880,7 +897,7 @@ async def test_list_github_issues_ordering( ) async def test_github_work_summary_all_group_by_options(db_session, sample_issues, group_by): """Test all valid group_by options.""" - from memory.api.MCP.github import github_work_summary + from memory.api.MCP.servers.github import github_work_summary now = datetime.now(timezone.utc) since = (now - timedelta(days=30)).isoformat() @@ -906,7 +923,7 @@ async def test_list_github_issues_label_filtering( db_session, sample_issues, labels, expected_count ): """Test various label filtering scenarios.""" - from memory.api.MCP.github import list_github_issues + from memory.api.MCP.servers.github import list_github_issues with patch("memory.api.MCP.github.make_session", return_value=db_session): results = await list_github_issues(labels=labels) @@ -1014,7 +1031,7 @@ def sample_pr_with_data(db_session): @pytest.mark.asyncio async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_data): """Test that github_issue_details includes PR data for PRs.""" - from memory.api.MCP.github import github_issue_details + from memory.api.MCP.servers.github import github_issue_details with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_issue_details(repo="owner/repo1", number=999) @@ -1034,7 +1051,7 @@ async def test_github_issue_details_includes_pr_data(db_session, sample_pr_with_ @pytest.mark.asyncio async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_issues): """Test that github_issue_details does not include pr_data for issues.""" - from memory.api.MCP.github import github_issue_details + from memory.api.MCP.servers.github import github_issue_details with patch("memory.api.MCP.github.make_session", return_value=db_session): result = await github_issue_details(repo="owner/repo1", number=1) @@ -1045,7 +1062,7 @@ async def test_github_issue_details_no_pr_data_for_issues(db_session, sample_iss def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data): """Test that _serialize_issue includes pr_data when include_content=True.""" - from memory.api.MCP.github import _serialize_issue + from memory.api.MCP.servers.github import _serialize_issue result = _serialize_issue(sample_pr_with_data, include_content=True) @@ -1056,7 +1073,7 @@ def test_serialize_issue_includes_pr_data(db_session, sample_pr_with_data): def test_serialize_issue_no_pr_data_without_content(db_session, sample_pr_with_data): """Test that _serialize_issue excludes pr_data when include_content=False.""" - from memory.api.MCP.github import _serialize_issue + from memory.api.MCP.servers.github import _serialize_issue result = _serialize_issue(sample_pr_with_data, include_content=False) diff --git a/tests/memory/api/MCP/test_people.py b/tests/memory/api/MCP/test_people.py index d4ec183..6157a94 100644 --- a/tests/memory/api/MCP/test_people.py +++ b/tests/memory/api/MCP/test_people.py @@ -4,10 +4,27 @@ import sys import pytest from unittest.mock import AsyncMock, Mock, MagicMock, patch -# Mock the mcp module and all its submodules before importing anything that uses it +# Mock FastMCP - this creates a decorator factory that passes through the function unchanged +class MockFastMCP: + def __init__(self, name): + self.name = name + + def tool(self): + def decorator(func): + return func + return decorator + + +# Mock the fastmcp module before importing anything that uses it +_mock_fastmcp = MagicMock() +_mock_fastmcp.FastMCP = MockFastMCP +sys.modules["fastmcp"] = _mock_fastmcp + +# Mock the mcp module and all its submodules _mock_mcp = MagicMock() -_mock_mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +_mock_mcp.tool = lambda: lambda f: f sys.modules["mcp"] = _mock_mcp +sys.modules["mcp.types"] = MagicMock() sys.modules["mcp.server"] = MagicMock() sys.modules["mcp.server.auth"] = MagicMock() sys.modules["mcp.server.auth.handlers"] = MagicMock() @@ -20,7 +37,7 @@ sys.modules["mcp.server.fastmcp.server"] = MagicMock() # Also mock the memory.api.MCP.base module to avoid MCP imports _mock_base = MagicMock() _mock_base.mcp = MagicMock() -_mock_base.mcp.tool = lambda: lambda f: f # Make @mcp.tool() a no-op decorator +_mock_base.mcp.tool = lambda: lambda f: f sys.modules["memory.api.MCP.base"] = _mock_base from memory.common.db.models import Person @@ -97,7 +114,7 @@ def sample_people(db_session): @pytest.mark.asyncio async def test_add_person_success(db_session): """Test adding a new person.""" - from memory.api.MCP.people import add_person + from memory.api.MCP.servers.people import add_person mock_task = Mock() mock_task.id = "task-123" @@ -128,7 +145,7 @@ async def test_add_person_success(db_session): @pytest.mark.asyncio async def test_add_person_already_exists(db_session, sample_people): """Test adding a person that already exists.""" - from memory.api.MCP.people import add_person + from memory.api.MCP.servers.people import add_person with patch("memory.api.MCP.people.make_session", return_value=db_session): with pytest.raises(ValueError, match="already exists"): @@ -141,7 +158,7 @@ async def test_add_person_already_exists(db_session, sample_people): @pytest.mark.asyncio async def test_add_person_minimal(db_session): """Test adding a person with minimal data.""" - from memory.api.MCP.people import add_person + from memory.api.MCP.servers.people import add_person mock_task = Mock() mock_task.id = "task-456" @@ -166,7 +183,7 @@ async def test_add_person_minimal(db_session): @pytest.mark.asyncio async def test_update_person_info_success(db_session, sample_people): """Test updating a person's info.""" - from memory.api.MCP.people import update_person_info + from memory.api.MCP.servers.people import update_person_info mock_task = Mock() mock_task.id = "task-789" @@ -188,7 +205,7 @@ async def test_update_person_info_success(db_session, sample_people): @pytest.mark.asyncio async def test_update_person_info_not_found(db_session, sample_people): """Test updating a person that doesn't exist.""" - from memory.api.MCP.people import update_person_info + from memory.api.MCP.servers.people import update_person_info with patch("memory.api.MCP.people.make_session", return_value=db_session): with pytest.raises(ValueError, match="not found"): @@ -201,7 +218,7 @@ async def test_update_person_info_not_found(db_session, sample_people): @pytest.mark.asyncio async def test_update_person_info_with_merge_params(db_session, sample_people): """Test that update passes all merge parameters.""" - from memory.api.MCP.people import update_person_info + from memory.api.MCP.servers.people import update_person_info mock_task = Mock() mock_task.id = "task-merge" @@ -234,7 +251,7 @@ async def test_update_person_info_with_merge_params(db_session, sample_people): @pytest.mark.asyncio async def test_get_person_found(db_session, sample_people): """Test getting a person that exists.""" - from memory.api.MCP.people import get_person + from memory.api.MCP.servers.people import get_person with patch("memory.api.MCP.people.make_session", return_value=db_session): result = await get_person(identifier="alice_chen") @@ -251,7 +268,7 @@ async def test_get_person_found(db_session, sample_people): @pytest.mark.asyncio async def test_get_person_not_found(db_session, sample_people): """Test getting a person that doesn't exist.""" - from memory.api.MCP.people import get_person + from memory.api.MCP.servers.people import get_person with patch("memory.api.MCP.people.make_session", return_value=db_session): result = await get_person(identifier="nonexistent_person") @@ -267,7 +284,7 @@ async def test_get_person_not_found(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_no_filters(db_session, sample_people): """Test listing all people without filters.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people() @@ -282,7 +299,7 @@ async def test_list_people_no_filters(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_filter_by_tags(db_session, sample_people): """Test filtering by tags.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(tags=["work"]) @@ -294,7 +311,7 @@ async def test_list_people_filter_by_tags(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_filter_by_search(db_session, sample_people): """Test filtering by search term.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(search="alice") @@ -306,7 +323,7 @@ async def test_list_people_filter_by_search(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_search_in_notes(db_session, sample_people): """Test that search works on notes content.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(search="climbing") @@ -318,7 +335,7 @@ async def test_list_people_search_in_notes(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_limit(db_session, sample_people): """Test limiting results.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(limit=1) @@ -329,7 +346,7 @@ async def test_list_people_limit(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_limit_max_enforced(db_session, sample_people): """Test that limit is capped at 200.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): # Request 500 but should be capped at 200 @@ -342,7 +359,7 @@ async def test_list_people_limit_max_enforced(db_session, sample_people): @pytest.mark.asyncio async def test_list_people_combined_filters(db_session, sample_people): """Test combining tag and search filters.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(tags=["work"], search="chen") @@ -359,7 +376,7 @@ async def test_list_people_combined_filters(db_session, sample_people): @pytest.mark.asyncio async def test_delete_person_success(db_session, sample_people): """Test deleting a person.""" - from memory.api.MCP.people import delete_person + from memory.api.MCP.servers.people import delete_person with patch("memory.api.MCP.people.make_session", return_value=db_session): result = await delete_person(identifier="bob_smith") @@ -376,7 +393,7 @@ async def test_delete_person_success(db_session, sample_people): @pytest.mark.asyncio async def test_delete_person_not_found(db_session, sample_people): """Test deleting a person that doesn't exist.""" - from memory.api.MCP.people import delete_person + from memory.api.MCP.servers.people import delete_person with patch("memory.api.MCP.people.make_session", return_value=db_session): with pytest.raises(ValueError, match="not found"): @@ -390,7 +407,7 @@ async def test_delete_person_not_found(db_session, sample_people): def test_person_to_dict(sample_people): """Test the _person_to_dict helper function.""" - from memory.api.MCP.people import _person_to_dict + from memory.api.MCP.servers.people import _person_to_dict person = sample_people[0] result = _person_to_dict(person) @@ -406,7 +423,7 @@ def test_person_to_dict(sample_people): def test_person_to_dict_empty_fields(db_session): """Test _person_to_dict with empty optional fields.""" - from memory.api.MCP.people import _person_to_dict + from memory.api.MCP.servers.people import _person_to_dict person = Person( identifier="empty_person", @@ -448,7 +465,7 @@ def test_person_to_dict_empty_fields(db_session): ) async def test_list_people_various_tags(db_session, sample_people, tag, expected_count): """Test filtering by various tags.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(tags=[tag]) @@ -473,7 +490,7 @@ async def test_list_people_various_searches( db_session, sample_people, search_term, expected_identifiers ): """Test search with various terms.""" - from memory.api.MCP.people import list_people + from memory.api.MCP.servers.people import list_people with patch("memory.api.MCP.people.make_session", return_value=db_session): results = await list_people(search=search_term) diff --git a/tests/memory/api/search/test_types.py b/tests/memory/api/search/test_types.py index 86fd879..ac21417 100644 --- a/tests/memory/api/search/test_types.py +++ b/tests/memory/api/search/test_types.py @@ -69,8 +69,12 @@ class TestSearchResult: result = SearchResult.from_source_item(source, chunks) assert result.search_score == 0.9 - def test_search_score_multiple_chunks_uses_mean(self): - """Multiple chunks should use mean of relevance scores, not sum.""" + def test_search_score_multiple_chunks_uses_max(self): + """Multiple chunks should use max of relevance scores. + + Using max finds documents with at least one highly relevant section, + which is better for 'half-remembered' searches where users recall one detail. + """ source = self._make_source_item() chunks = [ self._make_chunk(0.9), @@ -79,8 +83,8 @@ class TestSearchResult: ] result = SearchResult.from_source_item(source, chunks) - # Mean of 0.9, 0.7, 0.8 = 0.8 - assert result.search_score == pytest.approx(0.8) + # Max of 0.9, 0.7, 0.8 = 0.9 + assert result.search_score == pytest.approx(0.9) def test_search_score_empty_chunks(self): """Empty chunk list should result in None or 0 score."""