mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-29 14:16:09 +02:00
allow claude scopes
This commit is contained in:
parent
50601ad930
commit
42dd94c0ef
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import cast
|
||||
|
||||
@ -18,7 +17,11 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
||||
from memory.api.MCP.oauth_provider import (
|
||||
SimpleOAuthProvider,
|
||||
ALLOWED_SCOPES,
|
||||
BASE_SCOPES,
|
||||
)
|
||||
from memory.common import settings
|
||||
from memory.common.db.connection import make_session
|
||||
from memory.common.db.models import OAuthState
|
||||
@ -61,13 +64,13 @@ templates = Jinja2Templates(directory=template_dir)
|
||||
oauth_provider = SimpleOAuthProvider()
|
||||
auth_settings = AuthSettings(
|
||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
|
||||
client_registration_options=ClientRegistrationOptions(
|
||||
enabled=True,
|
||||
valid_scopes=["read", "write"],
|
||||
default_scopes=["read"],
|
||||
valid_scopes=ALLOWED_SCOPES,
|
||||
default_scopes=BASE_SCOPES,
|
||||
),
|
||||
required_scopes=["read", "write"],
|
||||
required_scopes=BASE_SCOPES,
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
@ -81,8 +84,8 @@ mcp = FastMCP(
|
||||
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||
async def oauth_protected_resource(request: Request):
|
||||
"""OAuth 2.0 Protected Resource Metadata."""
|
||||
logger.info("Protected resource metadata requested")
|
||||
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
||||
metadata = oauth_provider.get_protected_resource_metadata()
|
||||
return JSONResponse(metadata)
|
||||
|
||||
|
||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||
|
@ -27,6 +27,10 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_SCOPES = ["read", "write", "claudeai"]
|
||||
BASE_SCOPES = ["read"]
|
||||
RW_SCOPES = ["read", "write"]
|
||||
|
||||
|
||||
# Token configuration constants
|
||||
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
||||
@ -164,6 +168,36 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
)
|
||||
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||
|
||||
# Determine which scopes to grant
|
||||
requested_scopes = getattr(params, "scopes", None) or []
|
||||
|
||||
if not requested_scopes:
|
||||
# Use default scopes if none requested
|
||||
requested_scopes = BASE_SCOPES
|
||||
|
||||
# Validate requested scopes are allowed
|
||||
requested_scopes_set = set(requested_scopes)
|
||||
allowed_scopes_set = set(ALLOWED_SCOPES)
|
||||
|
||||
if not requested_scopes_set.issubset(allowed_scopes_set):
|
||||
invalid_scopes = requested_scopes_set - allowed_scopes_set
|
||||
raise ValueError(f"Invalid scopes: {', '.join(invalid_scopes)}")
|
||||
|
||||
# Check if requested scopes are in client's registered scopes
|
||||
client_scopes = (
|
||||
getattr(client, "scope", "").split() if hasattr(client, "scope") else []
|
||||
)
|
||||
client_scopes_set = set(client_scopes)
|
||||
|
||||
if client_scopes and not requested_scopes_set.issubset(client_scopes_set):
|
||||
invalid_scopes = requested_scopes_set - client_scopes_set
|
||||
logger.error(
|
||||
f"❌ Client was not registered with scope(s): {invalid_scopes}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Client was not registered with scope {', '.join(invalid_scopes)}"
|
||||
)
|
||||
|
||||
# Store the authorization parameters in database
|
||||
with make_session() as session:
|
||||
oauth_state = OAuthState(
|
||||
@ -175,7 +209,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
).lower()
|
||||
== "true",
|
||||
code_challenge=params.code_challenge or "",
|
||||
scopes=["read", "write"], # Default scopes
|
||||
scopes=requested_scopes,
|
||||
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||
)
|
||||
session.add(oauth_state)
|
||||
@ -240,6 +274,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
if not auth_code:
|
||||
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||
raise ValueError("Invalid authorization code")
|
||||
|
||||
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||
|
||||
async def exchange_authorization_code(
|
||||
@ -272,7 +307,10 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
|
||||
# Query for active (non-expired) session
|
||||
user_session = session.query(UserSession).get(token)
|
||||
if not user_session or user_session.expires_at < now:
|
||||
if not user_session:
|
||||
return None
|
||||
|
||||
if user_session.expires_at < now:
|
||||
return None
|
||||
|
||||
return AccessToken(
|
||||
@ -390,7 +428,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
"""Return metadata about the protected resource."""
|
||||
return {
|
||||
"resource_server": settings.SERVER_URL,
|
||||
"scopes_supported": ["read", "write"],
|
||||
"scopes_supported": ALLOWED_SCOPES,
|
||||
"bearer_methods_supported": ["header"],
|
||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
@ -399,12 +437,12 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
"protected_resources": [
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||
"scopes": ["read", "write"],
|
||||
"scopes": RW_SCOPES,
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
{
|
||||
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||
"scopes": ["read", "write"],
|
||||
"scopes": RW_SCOPES,
|
||||
"http_methods": ["POST", "GET"],
|
||||
},
|
||||
],
|
||||
|
Loading…
x
Reference in New Issue
Block a user