diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index 04a4b61..c6ca254 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -1,5 +1,4 @@ import logging -import os import pathlib from typing import cast @@ -18,7 +17,11 @@ from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse from starlette.templating import Jinja2Templates -from memory.api.MCP.oauth_provider import SimpleOAuthProvider +from memory.api.MCP.oauth_provider import ( + SimpleOAuthProvider, + ALLOWED_SCOPES, + BASE_SCOPES, +) from memory.common import settings from memory.common.db.connection import make_session from memory.common.db.models import OAuthState @@ -61,13 +64,13 @@ templates = Jinja2Templates(directory=template_dir) oauth_provider = SimpleOAuthProvider() auth_settings = AuthSettings( issuer_url=cast(AnyHttpUrl, settings.SERVER_URL), - resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), + resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore client_registration_options=ClientRegistrationOptions( enabled=True, - valid_scopes=["read", "write"], - default_scopes=["read"], + valid_scopes=ALLOWED_SCOPES, + default_scopes=BASE_SCOPES, ), - required_scopes=["read", "write"], + required_scopes=BASE_SCOPES, ) mcp = FastMCP( @@ -81,8 +84,8 @@ mcp = FastMCP( @mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"]) async def oauth_protected_resource(request: Request): """OAuth 2.0 Protected Resource Metadata.""" - logger.info("Protected resource metadata requested") - return JSONResponse(oauth_provider.get_protected_resource_metadata()) + metadata = oauth_provider.get_protected_resource_metadata() + return JSONResponse(metadata) def login_form(request: Request, form_data: dict, error: str | None = None): diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index ad96000..f7b8dfd 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -27,6 +27,10 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken logger = logging.getLogger(__name__) +ALLOWED_SCOPES = ["read", "write", "claudeai"] +BASE_SCOPES = ["read"] +RW_SCOPES = ["read", "write"] + # Token configuration constants ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days @@ -164,6 +168,36 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): ) raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}") + # Determine which scopes to grant + requested_scopes = getattr(params, "scopes", None) or [] + + if not requested_scopes: + # Use default scopes if none requested + requested_scopes = BASE_SCOPES + + # Validate requested scopes are allowed + requested_scopes_set = set(requested_scopes) + allowed_scopes_set = set(ALLOWED_SCOPES) + + if not requested_scopes_set.issubset(allowed_scopes_set): + invalid_scopes = requested_scopes_set - allowed_scopes_set + raise ValueError(f"Invalid scopes: {', '.join(invalid_scopes)}") + + # Check if requested scopes are in client's registered scopes + client_scopes = ( + getattr(client, "scope", "").split() if hasattr(client, "scope") else [] + ) + client_scopes_set = set(client_scopes) + + if client_scopes and not requested_scopes_set.issubset(client_scopes_set): + invalid_scopes = requested_scopes_set - client_scopes_set + logger.error( + f"❌ Client was not registered with scope(s): {invalid_scopes}" + ) + raise ValueError( + f"Client was not registered with scope {', '.join(invalid_scopes)}" + ) + # Store the authorization parameters in database with make_session() as session: oauth_state = OAuthState( @@ -175,7 +209,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): ).lower() == "true", code_challenge=params.code_challenge or "", - scopes=["read", "write"], # Default scopes + scopes=requested_scopes, expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry ) session.add(oauth_state) @@ -240,6 +274,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): if not auth_code: logger.error(f"Invalid authorization code: {authorization_code}") raise ValueError("Invalid authorization code") + return AuthorizationCode(**auth_code.serialize(code=True)) async def exchange_authorization_code( @@ -272,7 +307,10 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): # Query for active (non-expired) session user_session = session.query(UserSession).get(token) - if not user_session or user_session.expires_at < now: + if not user_session: + return None + + if user_session.expires_at < now: return None return AccessToken( @@ -390,7 +428,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): """Return metadata about the protected resource.""" return { "resource_server": settings.SERVER_URL, - "scopes_supported": ["read", "write"], + "scopes_supported": ALLOWED_SCOPES, "bearer_methods_supported": ["header"], "resource_documentation": f"{settings.SERVER_URL}/docs", "grant_types_supported": ["authorization_code", "refresh_token"], @@ -399,12 +437,12 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): "protected_resources": [ { "resource_uri": f"{settings.SERVER_URL}/mcp", - "scopes": ["read", "write"], + "scopes": RW_SCOPES, "http_methods": ["POST", "GET"], }, { "resource_uri": f"{settings.SERVER_URL}/mcp/", - "scopes": ["read", "write"], + "scopes": RW_SCOPES, "http_methods": ["POST", "GET"], }, ],