From 48c380b9035bc8f1e484c65c543e4f9f9073d8c0 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Wed, 24 Dec 2025 23:19:41 +0100 Subject: [PATCH] migrate to fastmcp --- requirements/requirements-api.txt | 3 +- requirements/requirements-common.txt | 6 +- src/memory/api/MCP/__init__.py | 24 +- src/memory/api/MCP/base.py | 121 ++++---- src/memory/api/MCP/manifest.py | 119 -------- src/memory/api/MCP/metadata.py | 119 -------- src/memory/api/MCP/oauth_provider.py | 44 ++- src/memory/api/MCP/servers/__init__.py | 17 ++ src/memory/api/MCP/{ => servers}/books.py | 10 +- .../api/MCP/{memory.py => servers/core.py} | 63 ++-- src/memory/api/MCP/{ => servers}/github.py | 59 +--- src/memory/api/MCP/servers/meta.py | 277 ++++++++++++++++++ src/memory/api/MCP/{ => servers}/people.py | 30 +- .../MCP/{schedules.py => servers/schedule.py} | 37 ++- src/memory/api/MCP/tools.py | 74 ----- src/memory/api/app.py | 53 +--- tools/deploy.sh | 1 - tools/diagnose.sh | 195 ++++++++++++ 18 files changed, 703 insertions(+), 549 deletions(-) delete mode 100644 src/memory/api/MCP/manifest.py delete mode 100644 src/memory/api/MCP/metadata.py create mode 100644 src/memory/api/MCP/servers/__init__.py rename src/memory/api/MCP/{ => servers}/books.py (92%) rename src/memory/api/MCP/{memory.py => servers/core.py} (91%) rename src/memory/api/MCP/{ => servers}/github.py (92%) create mode 100644 src/memory/api/MCP/servers/meta.py rename src/memory/api/MCP/{ => servers}/people.py (95%) rename src/memory/api/MCP/{schedules.py => servers/schedule.py} (89%) delete mode 100644 src/memory/api/MCP/tools.py create mode 100755 tools/diagnose.sh diff --git a/requirements/requirements-api.txt b/requirements/requirements-api.txt index 3e1a402..0fc683f 100644 --- a/requirements/requirements-api.txt +++ b/requirements/requirements-api.txt @@ -1,7 +1,8 @@ -fastapi==0.112.2 +fastapi>=0.115.12 uvicorn==0.29.0 python-jose==3.3.0 python-multipart==0.0.9 sqladmin==0.20.1 mcp==1.10.0 +fastmcp>=2.10.0 slowapi==0.1.9 \ No newline at end of file diff --git a/requirements/requirements-common.txt b/requirements/requirements-common.txt index 09674e1..9269445 100644 --- a/requirements/requirements-common.txt +++ b/requirements/requirements-common.txt @@ -1,14 +1,14 @@ sqlalchemy==2.0.30 psycopg2-binary==2.9.9 -pydantic==2.7.2 +pydantic>=2.11.7 alembic==1.13.1 dotenv==0.9.9 voyageai==0.3.2 qdrant-client==1.9.0 anthropic==0.69.0 openai==2.3.0 -# Pin the httpx version, as newer versions break the anthropic client -httpx==0.27.0 +# Updated for fastmcp>=2.10 compatibility (anthropic 0.69.0 supports httpx<1) +httpx>=0.28.1 celery[redis,sqs]==5.3.6 cryptography==43.0.0 bcrypt==4.1.2 \ No newline at end of file diff --git a/src/memory/api/MCP/__init__.py b/src/memory/api/MCP/__init__.py index 07f0efa..540920d 100644 --- a/src/memory/api/MCP/__init__.py +++ b/src/memory/api/MCP/__init__.py @@ -1,8 +1,16 @@ -import memory.api.MCP.tools -import memory.api.MCP.memory -import memory.api.MCP.metadata -import memory.api.MCP.schedules -import memory.api.MCP.books -import memory.api.MCP.manifest -import memory.api.MCP.github -import memory.api.MCP.people +""" +MCP server with composed subservers. 
+ +Subservers are mounted with prefixes: +- core: search_knowledge_base, observe, search_observations, create_note, note_files, fetch_file +- github: list_github_issues, search_github_issues, github_issue_details, github_work_summary, github_repo_overview +- people: add_person, update_person_info, get_person, list_people, delete_person +- schedule: schedule_message, list_scheduled_llm_calls, cancel_scheduled_llm_call +- books: all_books, read_book +- meta: get_metadata_schemas, get_all_tags, get_all_subjects, get_all_observation_types, get_current_time, get_authenticated_user, get_forecasts +""" + +# Import base to trigger subserver mounting +from memory.api.MCP.base import mcp, get_current_user + +__all__ = ["mcp", "get_current_user"] diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index 92c4ad0..719c4e0 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -1,60 +1,28 @@ import logging import pathlib -from typing import cast -from mcp.server.auth.handlers.authorize import AuthorizationRequest -from mcp.server.auth.handlers.token import ( - AuthorizationCodeRequest, - RefreshTokenRequest, - TokenRequest, -) -from mcp.server.auth.middleware.auth_context import get_access_token -from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions -from mcp.server.fastmcp import FastMCP -from mcp.shared.auth import OAuthClientMetadata -from pydantic import AnyHttpUrl +from fastmcp import FastMCP +from fastmcp.server.dependencies import get_access_token from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse from starlette.templating import Jinja2Templates -from memory.api.MCP.oauth_provider import ( - ALLOWED_SCOPES, - BASE_SCOPES, - SimpleOAuthProvider, -) +from memory.api.MCP.oauth_provider import SimpleOAuthProvider +from memory.api.MCP.servers.books import books_mcp +from memory.api.MCP.servers.core import core_mcp +from memory.api.MCP.servers.github import github_mcp +from memory.api.MCP.servers.meta import meta_mcp +from memory.api.MCP.servers.meta import set_auth_provider as set_meta_auth +from memory.api.MCP.servers.people import people_mcp +from memory.api.MCP.servers.schedule import schedule_mcp +from memory.api.MCP.servers.schedule import set_auth_provider as set_schedule_auth from memory.common import settings -from memory.common.db.connection import make_session +from memory.common.db.connection import make_session, get_engine from memory.common.db.models import OAuthState, UserSession from memory.common.db.models.users import HumanUser logger = logging.getLogger(__name__) - - -def validate_metadata(klass: type): - orig_validate = klass.model_validate - - def validate(data: dict): - data = dict(data) - if "redirect_uris" in data: - data["redirect_uris"] = [ - str(uri).replace("cursor://", "http://") - for uri in data["redirect_uris"] - ] - if "redirect_uri" in data: - data["redirect_uri"] = str(data["redirect_uri"]).replace( - "cursor://", "http://" - ) - - return orig_validate(data) - - klass.model_validate = validate - - -validate_metadata(OAuthClientMetadata) -validate_metadata(AuthorizationRequest) -validate_metadata(AuthorizationCodeRequest) -validate_metadata(RefreshTokenRequest) -validate_metadata(TokenRequest) +engine = get_engine() # Setup templates @@ -63,22 +31,10 @@ templates = Jinja2Templates(directory=template_dir) oauth_provider = SimpleOAuthProvider() -auth_settings = AuthSettings( - issuer_url=cast(AnyHttpUrl, settings.SERVER_URL), - 
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore - client_registration_options=ClientRegistrationOptions( - enabled=True, - valid_scopes=ALLOWED_SCOPES, - default_scopes=BASE_SCOPES, - ), - required_scopes=BASE_SCOPES, -) mcp = FastMCP( "memory", - stateless_http=True, - auth_server_provider=oauth_provider, - auth=auth_settings, + auth=oauth_provider, ) @@ -162,3 +118,52 @@ def get_current_user() -> dict: "client_id": access_token.client_id, "user": user_info, } + + +@mcp.custom_route("/health", methods=["GET"]) +async def health_check(request: Request): + """Health check endpoint that verifies all dependencies are accessible.""" + from sqlalchemy import text + + checks = {"mcp_oauth": "enabled"} + all_healthy = True + + # Check database connection + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + checks["database"] = "healthy" + except Exception as e: + logger.error(f"Database health check failed: {e}") + checks["database"] = "unhealthy" + all_healthy = False + + # Check Qdrant connection + try: + from memory.common.qdrant import get_qdrant_client + + client = get_qdrant_client() + client.get_collections() + checks["qdrant"] = "healthy" + except Exception as e: + logger.error(f"Qdrant health check failed: {e}") + checks["qdrant"] = "unhealthy" + all_healthy = False + + checks["status"] = "healthy" if all_healthy else "degraded" + status_code = 200 if all_healthy else 503 + return JSONResponse(checks, status_code=status_code) + + +# Inject auth provider into subservers that need it +set_schedule_auth(get_current_user) +set_meta_auth(get_current_user) + +# Mount all subservers onto the main MCP server +# Tools will be prefixed with their server name (e.g., core_search_knowledge_base) +mcp.mount(core_mcp, prefix="core") +mcp.mount(github_mcp, prefix="github") +mcp.mount(people_mcp, prefix="people") +mcp.mount(schedule_mcp, prefix="schedule") +mcp.mount(books_mcp, prefix="books") +mcp.mount(meta_mcp, prefix="meta") diff --git a/src/memory/api/MCP/manifest.py b/src/memory/api/MCP/manifest.py deleted file mode 100644 index 6dcd76a..0000000 --- a/src/memory/api/MCP/manifest.py +++ /dev/null @@ -1,119 +0,0 @@ -import asyncio -import aiohttp -from datetime import datetime - -from typing import TypedDict, NotRequired, Literal -from memory.api.MCP.tools import mcp - - -class BinaryProbs(TypedDict): - prob: float - - -class MultiProbs(TypedDict): - answerProbs: dict[str, float] - - -Probs = dict[str, BinaryProbs | MultiProbs] -OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"] - - -class MarketAnswer(TypedDict): - id: str - text: str - resolutionProbability: float - - -class MarketDetails(TypedDict): - id: str - createdTime: int - question: str - outcomeType: OutcomeType - textDescription: str - groupSlugs: list[str] - volume: float - isResolved: bool - answers: list[MarketAnswer] - - -class Market(TypedDict): - id: str - url: str - question: str - volume: int - createdTime: int - outcomeType: OutcomeType - createdAt: NotRequired[str] - description: NotRequired[str] - answers: NotRequired[dict[str, float]] - probability: NotRequired[float] - details: NotRequired[MarketDetails] - - -async def get_details(session: aiohttp.ClientSession, market_id: str): - async with session.get( - f"https://api.manifold.markets/v0/market/{market_id}" - ) as resp: - resp.raise_for_status() - return await resp.json() - - -async def format_market(session: aiohttp.ClientSession, market: Market): - if market.get("outcomeType") != "BINARY": - details = await 
get_details(session, market["id"]) - market["answers"] = { - answer["text"]: round( - answer.get("resolutionProbability") or answer.get("probability") or 0, 3 - ) - for answer in details["answers"] - } - if creationTime := market.get("createdTime"): - market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat() - - fields = [ - "id", - "name", - "url", - "question", - "volume", - "createdAt", - "details", - "probability", - "answers", - ] - return {k: v for k, v in market.items() if k in fields} - - -async def search_markets(term: str, min_volume: int = 1000, binary: bool = False): - async with aiohttp.ClientSession() as session: - async with session.get( - "https://api.manifold.markets/v0/search-markets", - params={ - "term": term, - "contractType": "BINARY" if binary else "ALL", - }, - ) as resp: - resp.raise_for_status() - markets = await resp.json() - - return await asyncio.gather( - *[ - format_market(session, market) - for market in markets - if market.get("volume", 0) >= min_volume - ] - ) - - -@mcp.tool() -async def get_forecasts( - term: str, min_volume: int = 1000, binary: bool = False -) -> list[dict]: - """Get prediction market forecasts for a given term. - - Args: - term: The term to search for. - min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold. - binary: Whether to only return binary markets. - """ - return await search_markets(term, min_volume, binary) diff --git a/src/memory/api/MCP/metadata.py b/src/memory/api/MCP/metadata.py deleted file mode 100644 index ce54bec..0000000 --- a/src/memory/api/MCP/metadata.py +++ /dev/null @@ -1,119 +0,0 @@ -import logging -from collections import defaultdict -from typing import Annotated, TypedDict, get_args, get_type_hints - -from memory.common import qdrant -from sqlalchemy import func - -from memory.api.MCP.tools import mcp -from memory.common.db.connection import make_session -from memory.common.db.models import SourceItem -from memory.common.db.models.source_items import AgentObservation - -logger = logging.getLogger(__name__) - - -class SchemaArg(TypedDict): - type: str | None - description: str | None - - -class CollectionMetadata(TypedDict): - schema: dict[str, SchemaArg] - size: int - - -def from_annotation(annotation: Annotated) -> SchemaArg | None: - try: - type_, description = get_args(annotation) - type_str = str(type_) - if type_str.startswith("typing."): - type_str = type_str[7:] - elif len((parts := type_str.split("'"))) > 1: - type_str = parts[1] - return SchemaArg(type=type_str, description=description) - except IndexError: - logger.error(f"Error from annotation: {annotation}") - return None - - -def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]: - if not hasattr(klass, "as_payload"): - return {} - - if not (payload_type := get_type_hints(klass.as_payload).get("return")): - return {} - - return { - name: schema - for name, arg in payload_type.__annotations__.items() - if (schema := from_annotation(arg)) - } - - -@mcp.tool() -async def get_metadata_schemas() -> dict[str, CollectionMetadata]: - """Get the metadata schema for each collection used in the knowledge base. - - These schemas can be used to filter the knowledge base. - - Returns: A mapping of collection names to their metadata schemas with field types and descriptions. 
- - Example: - ``` - { - "mail": {"subject": {"type": "str", "description": "The subject of the email."}}, - "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}} - } - """ - client = qdrant.get_qdrant_client() - sizes = qdrant.get_collection_sizes(client) - schemas = defaultdict(dict) - for klass in SourceItem.__subclasses__(): - for collection in klass.get_collections(): - schemas[collection].update(get_schema(klass)) - - return { - collection: CollectionMetadata(schema=schema, size=size) - for collection, schema in schemas.items() - if (size := sizes.get(collection)) - } - - -@mcp.tool() -async def get_all_tags() -> list[str]: - """Get all unique tags used across the entire knowledge base. - - Returns sorted list of tags from both observations and content. - """ - with make_session() as session: - tags_query = session.query(func.unnest(SourceItem.tags)).distinct() - return sorted({row[0] for row in tags_query if row[0] is not None}) - - -@mcp.tool() -async def get_all_subjects() -> list[str]: - """Get all unique subjects from observations about the user. - - Returns sorted list of subject identifiers used in observations. - """ - with make_session() as session: - return sorted( - r.subject for r in session.query(AgentObservation.subject).distinct() - ) - - -@mcp.tool() -async def get_all_observation_types() -> list[str]: - """Get all observation types that have been used. - - Standard types are belief, preference, behavior, contradiction, general, but there can be more. - """ - with make_session() as session: - return sorted( - { - r.observation_type - for r in session.query(AgentObservation.observation_type).distinct() - if r.observation_type is not None - } - ) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index 3f9ac9f..db5b6f4 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -5,11 +5,12 @@ from datetime import datetime, timezone from typing import Any, Optional, cast from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from fastmcp.server.auth import OAuthProvider +from fastmcp.server.auth.auth import AccessToken as FastMCPAccessToken from mcp.server.auth.provider import ( AccessToken, AuthorizationCode, AuthorizationParams, - OAuthAuthorizationServerProvider, RefreshToken, ) from mcp.shared.auth import OAuthClientInformationFull, OAuthToken @@ -133,8 +134,45 @@ def make_token( ) -class SimpleOAuthProvider(OAuthAuthorizationServerProvider): - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: +class SimpleOAuthProvider(OAuthProvider): + """OAuth provider that extends fastmcp's OAuthProvider with custom login flow.""" + + def __init__(self): + super().__init__( + base_url=settings.SERVER_URL, + issuer_url=settings.SERVER_URL, + client_registration_options=None, # We handle registration ourselves + required_scopes=BASE_SCOPES, + ) + + async def verify_token(self, token: str) -> FastMCPAccessToken | None: + """Verify an access token and return token info if valid.""" + with make_session() as session: + # Try as OAuth access token first + user_session = session.query(UserSession).get(token) + if user_session: + now = datetime.now(timezone.utc).replace(tzinfo=None) + if user_session.expires_at < now: + return None + return FastMCPAccessToken( + token=token, + client_id=user_session.oauth_state.client_id, + scopes=user_session.oauth_state.scopes, + ) + + # Try as bot API key + bot = session.query(User).filter(User.api_key 
== token).first()
+            if bot:
+                logger.info(f"Bot {bot.name} (id={bot.id}) authenticated via API key")
+                return FastMCPAccessToken(
+                    token=token,
+                    client_id=cast(str, bot.name or bot.email),
+                    scopes=["read", "write"],
+                )
+
+            return None
+
+    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
         """Get OAuth client information."""
         with make_session() as session:
             client = session.get(OAuthClientInformation, client_id)
diff --git a/src/memory/api/MCP/servers/__init__.py b/src/memory/api/MCP/servers/__init__.py
new file mode 100644
index 0000000..68122c0
--- /dev/null
+++ b/src/memory/api/MCP/servers/__init__.py
@@ -0,0 +1,17 @@
+"""MCP subservers for composable tool organization."""
+
+from memory.api.MCP.servers.core import core_mcp
+from memory.api.MCP.servers.github import github_mcp
+from memory.api.MCP.servers.people import people_mcp
+from memory.api.MCP.servers.schedule import schedule_mcp
+from memory.api.MCP.servers.books import books_mcp
+from memory.api.MCP.servers.meta import meta_mcp
+
+__all__ = [
+    "core_mcp",
+    "github_mcp",
+    "people_mcp",
+    "schedule_mcp",
+    "books_mcp",
+    "meta_mcp",
+]
diff --git a/src/memory/api/MCP/books.py b/src/memory/api/MCP/servers/books.py
similarity index 92%
rename from src/memory/api/MCP/books.py
rename to src/memory/api/MCP/servers/books.py
index ad102ef..d5989ff 100644
--- a/src/memory/api/MCP/books.py
+++ b/src/memory/api/MCP/servers/books.py
@@ -1,15 +1,19 @@
+"""MCP subserver for ebook access."""
+
 import logging
 
+from fastmcp import FastMCP
 from sqlalchemy.orm import joinedload
 
-from memory.api.MCP.tools import mcp
 from memory.common.db.connection import make_session
 from memory.common.db.models import Book, BookSection, BookSectionPayload
 
 logger = logging.getLogger(__name__)
 
+books_mcp = FastMCP("memory-books")
 
-@mcp.tool()
+
+@books_mcp.tool()
 async def all_books(sections: bool = False) -> list[dict]:
     """
     Get all books in the database.
@@ -31,7 +35,7 @@ async def all_books(sections: bool = False) -> list[dict]:
         return [book.as_payload(sections=sections) for book in books]
 
 
-@mcp.tool()
+@books_mcp.tool()
 def read_book(book_id: int, sections: list[int] = []) -> list[BookSectionPayload]:
     """
     Read a book from the database.
diff --git a/src/memory/api/MCP/memory.py b/src/memory/api/MCP/servers/core.py
similarity index 91%
rename from src/memory/api/MCP/memory.py
rename to src/memory/api/MCP/servers/core.py
index 8ea3fd0..f73f251 100644
--- a/src/memory/api/MCP/memory.py
+++ b/src/memory/api/MCP/servers/core.py
@@ -1,53 +1,45 @@
 """
-MCP tools for the epistemic sparring partner system.
+Core MCP subserver for knowledge base search, observations, and notes.
""" import base64 import logging import pathlib from datetime import datetime, timezone -from PIL import Image +from fastmcp import FastMCP +from PIL import Image from pydantic import BaseModel from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp from memory.api.search.search import search -from memory.api.search.types import SearchFilters, SearchConfig +from memory.api.search.types import SearchConfig, SearchFilters from memory.common import extract, settings from memory.common.celery_app import SYNC_NOTE, SYNC_OBSERVATION from memory.common.celery_app import app as celery_app from memory.common.collections import ALL_COLLECTIONS, OBSERVATION_COLLECTIONS from memory.common.db.connection import make_session -from memory.common.db.models import SourceItem, AgentObservation +from memory.common.db.models import AgentObservation, SourceItem from memory.common.formatters import observation logger = logging.getLogger(__name__) +core_mcp = FastMCP("memory-core") + def validate_path_within_directory( base_dir: pathlib.Path, requested_path: str ) -> pathlib.Path: - """Validate that a requested path resolves within the base directory. - - Prevents path traversal attacks using ../ or similar techniques. - - Args: - base_dir: The allowed base directory - requested_path: The user-provided path - - Returns: - The resolved absolute path if valid - - Raises: - ValueError: If the path would escape the base directory - """ + """Validate that a requested path resolves within the base directory.""" resolved = (base_dir / requested_path.lstrip("/")).resolve() base_resolved = base_dir.resolve() - if not str(resolved).startswith(str(base_resolved) + "/") and resolved != base_resolved: + if ( + not str(resolved).startswith(str(base_resolved) + "/") + and resolved != base_resolved + ): raise ValueError(f"Path escapes allowed directory: {requested_path}") return resolved @@ -63,7 +55,6 @@ def filter_observation_source_ids( items_query = session.query(AgentObservation.id) if tags: - # Use PostgreSQL array overlap operator with proper array casting items_query = items_query.filter( AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), ) @@ -89,7 +80,6 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] items_query = session.query(SourceItem.id) if tags: - # Use PostgreSQL array overlap operator with proper array casting items_query = items_query.filter( SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), ) @@ -102,7 +92,7 @@ def filter_source_ids(modalities: set[str], filters: SearchFilters) -> list[int] return source_ids -@mcp.tool() +@core_mcp.tool() async def search_knowledge_base( query: str, filters: SearchFilters = {}, @@ -141,7 +131,6 @@ async def search_knowledge_base( if not modalities: modalities = set(ALL_COLLECTIONS.keys()) - # Filter to valid collections, excluding observation collections modalities = (set(modalities) & ALL_COLLECTIONS.keys()) - OBSERVATION_COLLECTIONS search_filters = SearchFilters(**filters) @@ -167,7 +156,7 @@ class RawObservation(BaseModel): tags: list[str] = [] -@mcp.tool() +@core_mcp.tool() async def observe( observations: list[RawObservation], session_id: str | None = None, @@ -212,23 +201,23 @@ async def observe( logger.info("MCP: Observing") tasks = [ ( - observation, + obs, celery_app.send_task( SYNC_OBSERVATION, queue=f"{settings.CELERY_QUEUE_PREFIX}-notes", kwargs={ - "subject": observation.subject, - "content": observation.content, - 
"observation_type": observation.observation_type, - "confidences": observation.confidences, - "evidence": observation.evidence, - "tags": observation.tags, + "subject": obs.subject, + "content": obs.content, + "observation_type": obs.observation_type, + "confidences": obs.confidences, + "evidence": obs.evidence, + "tags": obs.tags, "session_id": session_id, "agent_model": agent_model, }, ), ) - for observation in observations + for obs in observations ] def short_content(obs: RawObservation) -> str: @@ -242,7 +231,7 @@ async def observe( } -@mcp.tool() +@core_mcp.tool() async def search_observations( query: str, subject: str = "", @@ -308,7 +297,7 @@ async def search_observations( ] -@mcp.tool() +@core_mcp.tool() async def create_note( subject: str, content: str, @@ -362,7 +351,7 @@ async def create_note( } -@mcp.tool() +@core_mcp.tool() async def note_files(path: str = "/"): """ List note files in the user's note storage. @@ -386,7 +375,7 @@ async def note_files(path: str = "/"): ] -@mcp.tool() +@core_mcp.tool() def fetch_file(filename: str) -> dict: """ Read file content with automatic type detection. diff --git a/src/memory/api/MCP/github.py b/src/memory/api/MCP/servers/github.py similarity index 92% rename from src/memory/api/MCP/github.py rename to src/memory/api/MCP/servers/github.py index 859c224..624c49f 100644 --- a/src/memory/api/MCP/github.py +++ b/src/memory/api/MCP/servers/github.py @@ -1,14 +1,14 @@ -"""MCP tools for GitHub issue tracking and management.""" +"""MCP subserver for GitHub issue tracking and management.""" import logging from datetime import datetime, timezone from typing import Any -from sqlalchemy import Text, case, desc, func, asc +from fastmcp import FastMCP +from sqlalchemy import Text, case, desc, func from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp from memory.api.search.search import search from memory.api.search.types import SearchConfig, SearchFilters from memory.common import extract @@ -17,6 +17,8 @@ from memory.common.db.models import GithubItem logger = logging.getLogger(__name__) +github_mcp = FastMCP("memory-github") + def _build_github_url(repo_path: str, number: int | None, kind: str) -> str: """Build GitHub URL from repo path and issue/PR number.""" @@ -53,7 +55,6 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st } if include_content: result["content"] = item.content - # Include PR-specific data if available if item.kind == "pr" and item.pr_data: result["pr_data"] = { "additions": item.pr_data.additions, @@ -62,12 +63,12 @@ def _serialize_issue(item: GithubItem, include_content: bool = False) -> dict[st "files": item.pr_data.files, "reviews": item.pr_data.reviews, "review_comments": item.pr_data.review_comments, - "diff": item.pr_data.diff, # Decompressed via property + "diff": item.pr_data.diff, } return result -@mcp.tool() +@github_mcp.tool() async def list_github_issues( repo: str | None = None, assignee: str | None = None, @@ -109,64 +110,48 @@ async def list_github_issues( with make_session() as session: query = session.query(GithubItem) - # Apply filters if repo: query = query.filter(GithubItem.repo_path == repo) - if assignee: query = query.filter(GithubItem.assignees.any(assignee)) - if author: query = query.filter(GithubItem.author == author) - if state: query = query.filter(GithubItem.state == state) - if kind: query = query.filter(GithubItem.kind == kind) else: - # Exclude comments by default, only show issues and PRs query 
= query.filter(GithubItem.kind.in_(["issue", "pr"])) - if labels: - # Match any label in the list using PostgreSQL array overlap query = query.filter( GithubItem.labels.op("&&")(sql_cast(labels, ARRAY(Text))) ) - if project_status: query = query.filter(GithubItem.project_status == project_status) - if project_field: for key, value in project_field.items(): - query = query.filter( - GithubItem.project_fields[key].astext == value - ) - + query = query.filter(GithubItem.project_fields[key].astext == value) if updated_since: since_dt = datetime.fromisoformat(updated_since.replace("Z", "+00:00")) query = query.filter(GithubItem.github_updated_at >= since_dt) - if updated_before: before_dt = datetime.fromisoformat(updated_before.replace("Z", "+00:00")) query = query.filter(GithubItem.github_updated_at <= before_dt) - # Apply ordering if order_by == "created": query = query.order_by(desc(GithubItem.created_at)) elif order_by == "number": query = query.order_by(desc(GithubItem.number)) - else: # default: updated + else: query = query.order_by(desc(GithubItem.github_updated_at)) query = query.limit(limit) - items = query.all() return [_serialize_issue(item) for item in items] -@mcp.tool() +@github_mcp.tool() async def search_github_issues( query: str, repo: str | None = None, @@ -191,7 +176,6 @@ async def search_github_issues( limit = min(limit, 100) - # Pre-filter source_ids if repo/state/kind filters are specified source_ids = None if repo or state or kind: with make_session() as session: @@ -206,7 +190,6 @@ async def search_github_issues( q = q.filter(GithubItem.kind.in_(["issue", "pr"])) source_ids = [item.id for item in q.all()] - # Use the existing search infrastructure data = extract.extract_text(query, skip_summary=True) config = SearchConfig(limit=limit, previews=True) filters = SearchFilters() @@ -220,7 +203,6 @@ async def search_github_issues( config=config, ) - # Fetch full issue details for the results output = [] with make_session() as session: for result in results: @@ -233,7 +215,7 @@ async def search_github_issues( return output -@mcp.tool() +@github_mcp.tool() async def github_issue_details( repo: str, number: int, @@ -267,7 +249,7 @@ async def github_issue_details( return _serialize_issue(item, include_content=True) -@mcp.tool() +@github_mcp.tool() async def github_work_summary( since: str, until: str | None = None, @@ -294,7 +276,6 @@ async def github_work_summary( else: until_dt = datetime.now(timezone.utc) - # Map group_by to SQL expression group_mappings = { "client": GithubItem.project_fields["EquiStamp.Client"].astext, "status": GithubItem.project_status, @@ -311,7 +292,6 @@ async def github_work_summary( group_col = group_mappings[group_by] with make_session() as session: - # Build base query for the period base_query = session.query(GithubItem).filter( GithubItem.github_updated_at >= since_dt, GithubItem.github_updated_at <= until_dt, @@ -321,7 +301,6 @@ async def github_work_summary( if repo: base_query = base_query.filter(GithubItem.repo_path == repo) - # Get aggregated counts by group agg_query = ( session.query( group_col.label("group_name"), @@ -346,7 +325,6 @@ async def github_work_summary( groups = agg_query.all() - # Build summary with sample issues for each group summary = [] total_issues = 0 total_prs = 0 @@ -358,7 +336,6 @@ async def github_work_summary( total_issues += issue_count total_prs += pr_count - # Get sample issues for this group sample_query = base_query.filter(group_col == group_name).limit(5) samples = [ { @@ -394,7 +371,7 @@ async def 
github_work_summary( } -@mcp.tool() +@github_mcp.tool() async def github_repo_overview( repo: str, ) -> dict: @@ -410,13 +387,6 @@ async def github_repo_overview( logger.info(f"github_repo_overview called: repo={repo}") with make_session() as session: - # Base query for this repo - base_query = session.query(GithubItem).filter( - GithubItem.repo_path == repo, - GithubItem.kind.in_(["issue", "pr"]), - ) - - # Get total counts counts_query = session.query( func.count(GithubItem.id).label("total"), func.count(case((GithubItem.kind == "issue", 1))).label("total_issues"), @@ -441,7 +411,6 @@ async def github_repo_overview( counts = counts_query.first() - # Status breakdown (for project_status) status_query = ( session.query( GithubItem.project_status.label("status"), @@ -458,7 +427,6 @@ async def github_repo_overview( status_breakdown = {row.status: row.count for row in status_query.all()} - # Top assignees (open issues only) assignee_query = ( session.query( func.unnest(GithubItem.assignees).label("assignee"), @@ -479,7 +447,6 @@ async def github_repo_overview( for row in assignee_query.all() ] - # Label counts label_query = ( session.query( func.unnest(GithubItem.labels).label("label"), diff --git a/src/memory/api/MCP/servers/meta.py b/src/memory/api/MCP/servers/meta.py new file mode 100644 index 0000000..c9f361d --- /dev/null +++ b/src/memory/api/MCP/servers/meta.py @@ -0,0 +1,277 @@ +"""MCP subserver for metadata, utilities, and forecasting.""" + +import asyncio +import logging +from collections import defaultdict +from datetime import datetime, timezone +from typing import Annotated, Literal, NotRequired, TypedDict, get_args, get_type_hints + +import aiohttp +from fastmcp import FastMCP +from sqlalchemy import func + +from memory.common import qdrant +from memory.common.db.connection import make_session +from memory.common.db.models import SourceItem +from memory.common.db.models.source_items import AgentObservation + +logger = logging.getLogger(__name__) + +meta_mcp = FastMCP("memory-meta") + +# Auth provider will be injected at mount time +_get_current_user = None + + +def set_auth_provider(get_current_user_func): + """Set the authentication provider function.""" + global _get_current_user + _get_current_user = get_current_user_func + + +def get_current_user() -> dict: + """Get the current authenticated user.""" + if _get_current_user is None: + return {"authenticated": False, "error": "Auth provider not configured"} + return _get_current_user() + + +# --- Metadata tools --- + + +class SchemaArg(TypedDict): + type: str | None + description: str | None + + +class CollectionMetadata(TypedDict): + schema: dict[str, SchemaArg] + size: int + + +def from_annotation(annotation: Annotated) -> SchemaArg | None: + try: + type_, description = get_args(annotation) + type_str = str(type_) + if type_str.startswith("typing."): + type_str = type_str[7:] + elif len((parts := type_str.split("'"))) > 1: + type_str = parts[1] + return SchemaArg(type=type_str, description=description) + except IndexError: + logger.error(f"Error from annotation: {annotation}") + return None + + +def get_schema(klass: type[SourceItem]) -> dict[str, SchemaArg]: + if not hasattr(klass, "as_payload"): + return {} + + if not (payload_type := get_type_hints(klass.as_payload).get("return")): + return {} + + return { + name: schema + for name, arg in payload_type.__annotations__.items() + if (schema := from_annotation(arg)) + } + + +@meta_mcp.tool() +async def get_metadata_schemas() -> dict[str, CollectionMetadata]: + """Get the 
metadata schema for each collection used in the knowledge base. + + These schemas can be used to filter the knowledge base. + + Returns: A mapping of collection names to their metadata schemas with field types and descriptions. + + Example: + ``` + { + "mail": {"subject": {"type": "str", "description": "The subject of the email."}}, + "chat": {"subject": {"type": "str", "description": "The subject of the chat message."}} + } + """ + client = qdrant.get_qdrant_client() + sizes = qdrant.get_collection_sizes(client) + schemas = defaultdict(dict) + for klass in SourceItem.__subclasses__(): + for collection in klass.get_collections(): + schemas[collection].update(get_schema(klass)) + + return { + collection: CollectionMetadata(schema=schema, size=size) + for collection, schema in schemas.items() + if (size := sizes.get(collection)) + } + + +@meta_mcp.tool() +async def get_all_tags() -> list[str]: + """Get all unique tags used across the entire knowledge base. + + Returns sorted list of tags from both observations and content. + """ + with make_session() as session: + tags_query = session.query(func.unnest(SourceItem.tags)).distinct() + return sorted({row[0] for row in tags_query if row[0] is not None}) + + +@meta_mcp.tool() +async def get_all_subjects() -> list[str]: + """Get all unique subjects from observations about the user. + + Returns sorted list of subject identifiers used in observations. + """ + with make_session() as session: + return sorted( + r.subject for r in session.query(AgentObservation.subject).distinct() + ) + + +@meta_mcp.tool() +async def get_all_observation_types() -> list[str]: + """Get all observation types that have been used. + + Standard types are belief, preference, behavior, contradiction, general, but there can be more. + """ + with make_session() as session: + return sorted( + { + r.observation_type + for r in session.query(AgentObservation.observation_type).distinct() + if r.observation_type is not None + } + ) + + +# --- Utility tools --- + + +@meta_mcp.tool() +async def get_current_time() -> dict: + """Get the current time in UTC.""" + logger.info("get_current_time tool called") + return {"current_time": datetime.now(timezone.utc).isoformat()} + + +@meta_mcp.tool() +async def get_authenticated_user() -> dict: + """Get information about the authenticated user.""" + return get_current_user() + + +# --- Forecasting tools --- + + +class BinaryProbs(TypedDict): + prob: float + + +class MultiProbs(TypedDict): + answerProbs: dict[str, float] + + +Probs = dict[str, BinaryProbs | MultiProbs] +OutcomeType = Literal["BINARY", "MULTIPLE_CHOICE"] + + +class MarketAnswer(TypedDict): + id: str + text: str + resolutionProbability: float + + +class MarketDetails(TypedDict): + id: str + createdTime: int + question: str + outcomeType: OutcomeType + textDescription: str + groupSlugs: list[str] + volume: float + isResolved: bool + answers: list[MarketAnswer] + + +class Market(TypedDict): + id: str + url: str + question: str + volume: int + createdTime: int + outcomeType: OutcomeType + createdAt: NotRequired[str] + description: NotRequired[str] + answers: NotRequired[dict[str, float]] + probability: NotRequired[float] + details: NotRequired[MarketDetails] + + +async def get_details(session: aiohttp.ClientSession, market_id: str): + async with session.get( + f"https://api.manifold.markets/v0/market/{market_id}" + ) as resp: + resp.raise_for_status() + return await resp.json() + + +async def format_market(session: aiohttp.ClientSession, market: Market): + if market.get("outcomeType") 
!= "BINARY": + details = await get_details(session, market["id"]) + market["answers"] = { + answer["text"]: round( + answer.get("resolutionProbability") or answer.get("probability") or 0, 3 + ) + for answer in details["answers"] + } + if creationTime := market.get("createdTime"): + market["createdAt"] = datetime.fromtimestamp(creationTime / 1000).isoformat() + + fields = [ + "id", + "name", + "url", + "question", + "volume", + "createdAt", + "details", + "probability", + "answers", + ] + return {k: v for k, v in market.items() if k in fields} + + +async def search_markets(term: str, min_volume: int = 1000, binary: bool = False): + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.manifold.markets/v0/search-markets", + params={ + "term": term, + "contractType": "BINARY" if binary else "ALL", + }, + ) as resp: + resp.raise_for_status() + markets = await resp.json() + + return await asyncio.gather( + *[ + format_market(session, market) + for market in markets + if market.get("volume", 0) >= min_volume + ] + ) + + +@meta_mcp.tool() +async def get_forecasts( + term: str, min_volume: int = 1000, binary: bool = False +) -> list[dict]: + """Get prediction market forecasts for a given term. + + Args: + term: The term to search for. + min_volume: The minimum volume of the market, in units of that market, so Mana for Manifold. + binary: Whether to only return binary markets. + """ + return await search_markets(term, min_volume, binary) diff --git a/src/memory/api/MCP/people.py b/src/memory/api/MCP/servers/people.py similarity index 95% rename from src/memory/api/MCP/people.py rename to src/memory/api/MCP/servers/people.py index 9982317..b92eea8 100644 --- a/src/memory/api/MCP/people.py +++ b/src/memory/api/MCP/servers/people.py @@ -1,23 +1,23 @@ -""" -MCP tools for tracking people. 
-""" +"""MCP subserver for tracking people.""" import logging from typing import Any +from fastmcp import FastMCP from sqlalchemy import Text from sqlalchemy import cast as sql_cast from sqlalchemy.dialects.postgresql import ARRAY -from memory.api.MCP.base import mcp -from memory.common.db.connection import make_session -from memory.common.db.models import Person +from memory.common import settings from memory.common.celery_app import SYNC_PERSON, UPDATE_PERSON from memory.common.celery_app import app as celery_app -from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models import Person logger = logging.getLogger(__name__) +people_mcp = FastMCP("memory-people") + def _person_to_dict(person: Person) -> dict[str, Any]: """Convert a Person model to a dictionary for API responses.""" @@ -32,7 +32,7 @@ def _person_to_dict(person: Person) -> dict[str, Any]: } -@mcp.tool() +@people_mcp.tool() async def add_person( identifier: str, display_name: str, @@ -67,7 +67,6 @@ async def add_person( """ logger.info(f"MCP: Adding person: {identifier}") - # Check if person already exists with make_session() as session: existing = session.query(Person).filter(Person.identifier == identifier).first() if existing: @@ -93,7 +92,7 @@ async def add_person( } -@mcp.tool() +@people_mcp.tool() async def update_person_info( identifier: str, display_name: str | None = None, @@ -135,7 +134,6 @@ async def update_person_info( """ logger.info(f"MCP: Updating person: {identifier}") - # Verify person exists with make_session() as session: person = session.query(Person).filter(Person.identifier == identifier).first() if not person: @@ -162,7 +160,7 @@ async def update_person_info( } -@mcp.tool() +@people_mcp.tool() async def get_person(identifier: str) -> dict | None: """ Get a person by their identifier. @@ -182,7 +180,7 @@ async def get_person(identifier: str) -> dict | None: return _person_to_dict(person) -@mcp.tool() +@people_mcp.tool() async def list_people( tags: list[str] | None = None, search: str | None = None, @@ -207,9 +205,7 @@ async def list_people( query = session.query(Person) if tags: - query = query.filter( - Person.tags.op("&&")(sql_cast(tags, ARRAY(Text))) - ) + query = query.filter(Person.tags.op("&&")(sql_cast(tags, ARRAY(Text)))) if search: search_term = f"%{search.lower()}%" @@ -225,7 +221,7 @@ async def list_people( return [_person_to_dict(p) for p in people] -@mcp.tool() +@people_mcp.tool() async def delete_person(identifier: str) -> dict: """ Delete a person by their identifier. diff --git a/src/memory/api/MCP/schedules.py b/src/memory/api/MCP/servers/schedule.py similarity index 89% rename from src/memory/api/MCP/schedules.py rename to src/memory/api/MCP/servers/schedule.py index 26b50ad..a44e9f1 100644 --- a/src/memory/api/MCP/schedules.py +++ b/src/memory/api/MCP/servers/schedule.py @@ -1,20 +1,37 @@ -""" -MCP tools for the epistemic sparring partner system. 
-""" +"""MCP subserver for scheduling messages and LLM calls.""" import logging from datetime import datetime, timezone from typing import Any, cast -from memory.api.MCP.base import get_current_user, mcp +from fastmcp import FastMCP + from memory.common.db.connection import make_session -from memory.common.db.models import ScheduledLLMCall, DiscordBotUser +from memory.common.db.models import DiscordBotUser, ScheduledLLMCall from memory.discord.messages import schedule_discord_message logger = logging.getLogger(__name__) +schedule_mcp = FastMCP("memory-schedule") -@mcp.tool() +# We need access to get_current_user from base - this will be injected at mount time +_get_current_user = None + + +def set_auth_provider(get_current_user_func): + """Set the authentication provider function.""" + global _get_current_user + _get_current_user = get_current_user_func + + +def get_current_user() -> dict: + """Get the current authenticated user.""" + if _get_current_user is None: + return {"authenticated": False, "error": "Auth provider not configured"} + return _get_current_user() + + +@schedule_mcp.tool() async def schedule_message( scheduled_time: str, message: str, @@ -61,10 +78,8 @@ async def schedule_message( if not discord_user and not discord_channel: raise ValueError("Either discord_user or discord_channel must be provided") - # Parse scheduled time try: scheduled_dt = datetime.fromisoformat(scheduled_time.replace("Z", "+00:00")) - # Ensure we store as naive datetime (remove timezone info for database storage) if scheduled_dt.tzinfo is not None: scheduled_dt = scheduled_dt.astimezone(timezone.utc).replace(tzinfo=None) except ValueError: @@ -98,7 +113,7 @@ async def schedule_message( } -@mcp.tool() +@schedule_mcp.tool() async def list_scheduled_llm_calls( status: str | None = None, limit: int | None = 50 ) -> dict[str, Any]: @@ -143,7 +158,7 @@ async def list_scheduled_llm_calls( } -@mcp.tool() +@schedule_mcp.tool() async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: """ Cancel a scheduled LLM call. @@ -164,7 +179,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: return {"error": "User not found", "user": current_user} with make_session() as session: - # Find the scheduled call scheduled_call = ( session.query(ScheduledLLMCall) .filter( @@ -180,7 +194,6 @@ async def cancel_scheduled_llm_call(scheduled_call_id: str) -> dict[str, Any]: if not scheduled_call.can_be_cancelled(): return {"error": f"Cannot cancel call with status: {scheduled_call.status}"} - # Update the status scheduled_call.status = "cancelled" session.commit() diff --git a/src/memory/api/MCP/tools.py b/src/memory/api/MCP/tools.py deleted file mode 100644 index b3efaf7..0000000 --- a/src/memory/api/MCP/tools.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -MCP tools for the epistemic sparring partner system. 
-""" - -import logging -from datetime import datetime, timezone - -from sqlalchemy import Text -from sqlalchemy import cast as sql_cast -from sqlalchemy.dialects.postgresql import ARRAY - -from memory.common.db.connection import make_session -from memory.common.db.models import AgentObservation, SourceItem -from memory.api.MCP.base import mcp, get_current_user - -logger = logging.getLogger(__name__) - - -def filter_observation_source_ids( - tags: list[str] | None = None, observation_types: list[str] | None = None -): - if not tags and not observation_types: - return None - - with make_session() as session: - items_query = session.query(AgentObservation.id) - - if tags: - # Use PostgreSQL array overlap operator with proper array casting - items_query = items_query.filter( - AgentObservation.tags.op("&&")(sql_cast(tags, ARRAY(Text))), - ) - if observation_types: - items_query = items_query.filter( - AgentObservation.observation_type.in_(observation_types) - ) - source_ids = [item.id for item in items_query.all()] - - return source_ids - - -def filter_source_ids( - modalities: set[str], - tags: list[str] | None = None, -): - if not tags: - return None - - with make_session() as session: - items_query = session.query(SourceItem.id) - - if tags: - # Use PostgreSQL array overlap operator with proper array casting - items_query = items_query.filter( - SourceItem.tags.op("&&")(sql_cast(tags, ARRAY(Text))), - ) - if modalities: - items_query = items_query.filter(SourceItem.modality.in_(modalities)) - source_ids = [item.id for item in items_query.all()] - - return source_ids - - -@mcp.tool() -async def get_current_time() -> dict: - """Get the current time in UTC.""" - logger.info("get_current_time tool called") - return {"current_time": datetime.now(timezone.utc).isoformat()} - - -@mcp.tool() -async def get_authenticated_user() -> dict: - """Get information about the authenticated user.""" - return get_current_user() diff --git a/src/memory/api/app.py b/src/memory/api/app.py index 959f276..44911a3 100644 --- a/src/memory/api/app.py +++ b/src/memory/api/app.py @@ -2,7 +2,6 @@ FastAPI application for the knowledge base. 
""" -import contextlib import os import logging import mimetypes @@ -35,15 +34,10 @@ limiter = Limiter( enabled=settings.API_RATE_LIMIT_ENABLED, ) +# Create the MCP http app to get its lifespan +mcp_http_app = mcp.http_app(stateless_http=True) -@contextlib.asynccontextmanager -async def lifespan(app: FastAPI): - async with contextlib.AsyncExitStack() as stack: - await stack.enter_async_context(mcp.session_manager.run()) - yield - - -app = FastAPI(title="Knowledge Base API", lifespan=lifespan) +app = FastAPI(title="Knowledge Base API", lifespan=mcp_http_app.lifespan) app.state.limiter = limiter # Rate limit exception handler @@ -158,46 +152,9 @@ app.include_router(auth_router) # Add health check to MCP server instead of main app -@mcp.custom_route("/health", methods=["GET"]) -async def health_check(request: Request): - """Health check endpoint that verifies all dependencies are accessible.""" - from fastapi.responses import JSONResponse - from sqlalchemy import text - - checks = {"mcp_oauth": "enabled"} - all_healthy = True - - # Check database connection - try: - with engine.connect() as conn: - conn.execute(text("SELECT 1")) - checks["database"] = "healthy" - except Exception as e: - # Log error details but don't expose in response - logger.error(f"Database health check failed: {e}") - checks["database"] = "unhealthy" - all_healthy = False - - # Check Qdrant connection - try: - from memory.common.qdrant import get_qdrant_client - - client = get_qdrant_client() - client.get_collections() - checks["qdrant"] = "healthy" - except Exception as e: - # Log error details but don't expose in response - logger.error(f"Qdrant health check failed: {e}") - checks["qdrant"] = "unhealthy" - all_healthy = False - - checks["status"] = "healthy" if all_healthy else "degraded" - status_code = 200 if all_healthy else 503 - return JSONResponse(checks, status_code=status_code) - - # Mount MCP server at root - OAuth endpoints need to be at root level -app.mount("/", mcp.streamable_http_app()) +# Health check is defined in MCP/base.py +app.mount("/", mcp_http_app) def main(reload: bool = False): diff --git a/tools/deploy.sh b/tools/deploy.sh index 37031be..04af30a 100755 --- a/tools/deploy.sh +++ b/tools/deploy.sh @@ -58,7 +58,6 @@ sync_code() { "$PROJECT_DIR/frontend" \ "$PROJECT_DIR/requirements" \ "$PROJECT_DIR/setup.py" \ - "$PROJECT_DIR/pyproject.toml" \ "$PROJECT_DIR/docker-compose.yaml" \ "$PROJECT_DIR/pytest.ini" \ "$REMOTE_HOST:$REMOTE_DIR/" diff --git a/tools/diagnose.sh b/tools/diagnose.sh new file mode 100755 index 0000000..4328cbf --- /dev/null +++ b/tools/diagnose.sh @@ -0,0 +1,195 @@ +#!/bin/bash +set -e + +REMOTE_HOST="memory" +REMOTE_DIR="/home/ec2-user/memory" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Safe diagnostic commands for the memory server." 
+ echo "" + echo "Commands:" + echo " logs [service] [lines] View docker logs (default: all services, 100 lines)" + echo " ps Show docker container status" + echo " disk Show disk usage" + echo " mem Show memory usage" + echo " top Show running processes" + echo " ls List directory contents" + echo " cat View file contents" + echo " tail [lines] Tail a file (default: 50 lines)" + echo " grep Search for pattern in files" + echo " db Run read-only SQL query" + echo " get [port] GET request to localhost (default port: 8000)" + echo " status Overall system status" + exit 1 +} + +remote() { + ssh "$REMOTE_HOST" "$@" +} + +docker_logs() { + local service="${1:-}" + local lines="${2:-100}" + if [ -n "$service" ]; then + echo -e "${GREEN}Logs for $service (last $lines lines):${NC}" + remote "cd $REMOTE_DIR && docker compose logs --tail=$lines $service" + else + echo -e "${GREEN}All logs (last $lines lines):${NC}" + remote "cd $REMOTE_DIR && docker compose logs --tail=$lines" + fi +} + +docker_ps() { + echo -e "${GREEN}Container status:${NC}" + remote "cd $REMOTE_DIR && docker compose ps" +} + +disk_usage() { + echo -e "${GREEN}Disk usage:${NC}" + remote "df -h && echo '' && du -sh $REMOTE_DIR/* 2>/dev/null | sort -h" +} + +mem_usage() { + echo -e "${GREEN}Memory usage:${NC}" + remote "free -h" +} + +show_top() { + echo -e "${GREEN}Top processes:${NC}" + remote "ps aux --sort=-%mem | head -20" +} + +list_dir() { + local path="${1:-.}" + # Ensure path is within project directory for safety + echo -e "${GREEN}Contents of $path:${NC}" + remote "cd $REMOTE_DIR && ls -la $path" +} + +cat_file() { + local file="$1" + if [ -z "$file" ]; then + echo -e "${RED}Error: No file specified${NC}" + exit 1 + fi + echo -e "${GREEN}Contents of $file:${NC}" + remote "cd $REMOTE_DIR && cat $file" +} + +tail_file() { + local file="$1" + local lines="${2:-50}" + if [ -z "$file" ]; then + echo -e "${RED}Error: No file specified${NC}" + exit 1 + fi + echo -e "${GREEN}Last $lines lines of $file:${NC}" + remote "cd $REMOTE_DIR && tail -n $lines $file" +} + +grep_files() { + local pattern="$1" + local path="${2:-.}" + if [ -z "$pattern" ]; then + echo -e "${RED}Error: No pattern specified${NC}" + exit 1 + fi + echo -e "${GREEN}Searching for '$pattern' in $path:${NC}" + remote "cd $REMOTE_DIR && grep -r --color=always '$pattern' $path || true" +} + +db_query() { + local query="$1" + if [ -z "$query" ]; then + echo -e "${RED}Error: No query specified${NC}" + exit 1 + fi + # Only allow SELECT queries for safety + if ! 
echo "$query" | grep -qi "^select"; then + echo -e "${RED}Error: Only SELECT queries are allowed${NC}" + exit 1 + fi + echo -e "${GREEN}Running query:${NC}" + remote "cd $REMOTE_DIR && docker compose exec -T postgres psql -U kb -d kb -c \"$query\"" +} + +http_get() { + local path="$1" + local port="${2:-8000}" + if [ -z "$path" ]; then + echo -e "${RED}Error: No path specified${NC}" + exit 1 + fi + # Ensure path starts with / + if [[ "$path" != /* ]]; then + path="/$path" + fi + echo -e "${GREEN}GET http://localhost:${port}${path}${NC}" + remote "curl -s -w '\n\nHTTP Status: %{http_code}\n' 'http://localhost:${port}${path}'" +} + +system_status() { + echo -e "${GREEN}=== System Status ===${NC}" + echo "" + docker_ps + echo "" + echo -e "${GREEN}=== Memory ===${NC}" + mem_usage + echo "" + echo -e "${GREEN}=== Disk ===${NC}" + remote "df -h /" + echo "" + echo -e "${GREEN}=== Recent Errors (last 20 lines) ===${NC}" + remote "cd $REMOTE_DIR && docker compose logs --tail=100 2>&1 | grep -i -E '(error|exception|failed|fatal)' | tail -20 || echo 'No recent errors found'" +} + +# Main +case "${1:-}" in + logs) + docker_logs "${2:-}" "${3:-}" + ;; + ps) + docker_ps + ;; + disk) + disk_usage + ;; + mem) + mem_usage + ;; + top) + show_top + ;; + ls) + list_dir "${2:-}" + ;; + cat) + cat_file "${2:-}" + ;; + tail) + tail_file "${2:-}" "${3:-}" + ;; + grep) + grep_files "${2:-}" "${3:-}" + ;; + db) + db_query "${2:-}" + ;; + get) + http_get "${2:-}" "${3:-}" + ;; + status) + system_status + ;; + *) + usage + ;; +esac