allow claude scopes

This commit is contained in:
EC2 Default User 2025-07-08 19:05:15 +00:00
parent 50601ad930
commit 42dd94c0ef
2 changed files with 54 additions and 13 deletions

View File

@ -1,5 +1,4 @@
import logging
import os
import pathlib
from typing import cast
@ -18,7 +17,11 @@ from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.templating import Jinja2Templates
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
from memory.api.MCP.oauth_provider import (
SimpleOAuthProvider,
ALLOWED_SCOPES,
BASE_SCOPES,
)
from memory.common import settings
from memory.common.db.connection import make_session
from memory.common.db.models import OAuthState
@ -61,13 +64,13 @@ templates = Jinja2Templates(directory=template_dir)
oauth_provider = SimpleOAuthProvider()
auth_settings = AuthSettings(
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL),
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=["read", "write"],
default_scopes=["read"],
valid_scopes=ALLOWED_SCOPES,
default_scopes=BASE_SCOPES,
),
required_scopes=["read", "write"],
required_scopes=BASE_SCOPES,
)
mcp = FastMCP(
@ -81,8 +84,8 @@ mcp = FastMCP(
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
async def oauth_protected_resource(request: Request):
"""OAuth 2.0 Protected Resource Metadata."""
logger.info("Protected resource metadata requested")
return JSONResponse(oauth_provider.get_protected_resource_metadata())
metadata = oauth_provider.get_protected_resource_metadata()
return JSONResponse(metadata)
def login_form(request: Request, form_data: dict, error: str | None = None):

View File

@ -27,6 +27,10 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
logger = logging.getLogger(__name__)
ALLOWED_SCOPES = ["read", "write", "claudeai"]
BASE_SCOPES = ["read"]
RW_SCOPES = ["read", "write"]
# Token configuration constants
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
@ -164,6 +168,36 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
)
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
# Determine which scopes to grant
requested_scopes = getattr(params, "scopes", None) or []
if not requested_scopes:
# Use default scopes if none requested
requested_scopes = BASE_SCOPES
# Validate requested scopes are allowed
requested_scopes_set = set(requested_scopes)
allowed_scopes_set = set(ALLOWED_SCOPES)
if not requested_scopes_set.issubset(allowed_scopes_set):
invalid_scopes = requested_scopes_set - allowed_scopes_set
raise ValueError(f"Invalid scopes: {', '.join(invalid_scopes)}")
# Check if requested scopes are in client's registered scopes
client_scopes = (
getattr(client, "scope", "").split() if hasattr(client, "scope") else []
)
client_scopes_set = set(client_scopes)
if client_scopes and not requested_scopes_set.issubset(client_scopes_set):
invalid_scopes = requested_scopes_set - client_scopes_set
logger.error(
f"❌ Client was not registered with scope(s): {invalid_scopes}"
)
raise ValueError(
f"Client was not registered with scope {', '.join(invalid_scopes)}"
)
# Store the authorization parameters in database
with make_session() as session:
oauth_state = OAuthState(
@ -175,7 +209,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
).lower()
== "true",
code_challenge=params.code_challenge or "",
scopes=["read", "write"], # Default scopes
scopes=requested_scopes,
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
)
session.add(oauth_state)
@ -240,6 +274,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
if not auth_code:
logger.error(f"Invalid authorization code: {authorization_code}")
raise ValueError("Invalid authorization code")
return AuthorizationCode(**auth_code.serialize(code=True))
async def exchange_authorization_code(
@ -272,7 +307,10 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
# Query for active (non-expired) session
user_session = session.query(UserSession).get(token)
if not user_session or user_session.expires_at < now:
if not user_session:
return None
if user_session.expires_at < now:
return None
return AccessToken(
@ -390,7 +428,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
"""Return metadata about the protected resource."""
return {
"resource_server": settings.SERVER_URL,
"scopes_supported": ["read", "write"],
"scopes_supported": ALLOWED_SCOPES,
"bearer_methods_supported": ["header"],
"resource_documentation": f"{settings.SERVER_URL}/docs",
"grant_types_supported": ["authorization_code", "refresh_token"],
@ -399,12 +437,12 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
"protected_resources": [
{
"resource_uri": f"{settings.SERVER_URL}/mcp",
"scopes": ["read", "write"],
"scopes": RW_SCOPES,
"http_methods": ["POST", "GET"],
},
{
"resource_uri": f"{settings.SERVER_URL}/mcp/",
"scopes": ["read", "write"],
"scopes": RW_SCOPES,
"http_methods": ["POST", "GET"],
},
],