mirror of
https://github.com/mruwnik/memory.git
synced 2025-07-30 06:36:07 +02:00
allow claude scopes
This commit is contained in:
parent
50601ad930
commit
42dd94c0ef
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
@ -18,7 +17,11 @@ from starlette.requests import Request
|
|||||||
from starlette.responses import JSONResponse, RedirectResponse
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
from starlette.templating import Jinja2Templates
|
from starlette.templating import Jinja2Templates
|
||||||
|
|
||||||
from memory.api.MCP.oauth_provider import SimpleOAuthProvider
|
from memory.api.MCP.oauth_provider import (
|
||||||
|
SimpleOAuthProvider,
|
||||||
|
ALLOWED_SCOPES,
|
||||||
|
BASE_SCOPES,
|
||||||
|
)
|
||||||
from memory.common import settings
|
from memory.common import settings
|
||||||
from memory.common.db.connection import make_session
|
from memory.common.db.connection import make_session
|
||||||
from memory.common.db.models import OAuthState
|
from memory.common.db.models import OAuthState
|
||||||
@ -61,13 +64,13 @@ templates = Jinja2Templates(directory=template_dir)
|
|||||||
oauth_provider = SimpleOAuthProvider()
|
oauth_provider = SimpleOAuthProvider()
|
||||||
auth_settings = AuthSettings(
|
auth_settings = AuthSettings(
|
||||||
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
issuer_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
||||||
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL),
|
resource_server_url=cast(AnyHttpUrl, settings.SERVER_URL), # type: ignore
|
||||||
client_registration_options=ClientRegistrationOptions(
|
client_registration_options=ClientRegistrationOptions(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
valid_scopes=["read", "write"],
|
valid_scopes=ALLOWED_SCOPES,
|
||||||
default_scopes=["read"],
|
default_scopes=BASE_SCOPES,
|
||||||
),
|
),
|
||||||
required_scopes=["read", "write"],
|
required_scopes=BASE_SCOPES,
|
||||||
)
|
)
|
||||||
|
|
||||||
mcp = FastMCP(
|
mcp = FastMCP(
|
||||||
@ -81,8 +84,8 @@ mcp = FastMCP(
|
|||||||
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
|
||||||
async def oauth_protected_resource(request: Request):
|
async def oauth_protected_resource(request: Request):
|
||||||
"""OAuth 2.0 Protected Resource Metadata."""
|
"""OAuth 2.0 Protected Resource Metadata."""
|
||||||
logger.info("Protected resource metadata requested")
|
metadata = oauth_provider.get_protected_resource_metadata()
|
||||||
return JSONResponse(oauth_provider.get_protected_resource_metadata())
|
return JSONResponse(metadata)
|
||||||
|
|
||||||
|
|
||||||
def login_form(request: Request, form_data: dict, error: str | None = None):
|
def login_form(request: Request, form_data: dict, error: str | None = None):
|
||||||
|
@ -27,6 +27,10 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ALLOWED_SCOPES = ["read", "write", "claudeai"]
|
||||||
|
BASE_SCOPES = ["read"]
|
||||||
|
RW_SCOPES = ["read", "write"]
|
||||||
|
|
||||||
|
|
||||||
# Token configuration constants
|
# Token configuration constants
|
||||||
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days
|
||||||
@ -164,6 +168,36 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
)
|
)
|
||||||
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}")
|
||||||
|
|
||||||
|
# Determine which scopes to grant
|
||||||
|
requested_scopes = getattr(params, "scopes", None) or []
|
||||||
|
|
||||||
|
if not requested_scopes:
|
||||||
|
# Use default scopes if none requested
|
||||||
|
requested_scopes = BASE_SCOPES
|
||||||
|
|
||||||
|
# Validate requested scopes are allowed
|
||||||
|
requested_scopes_set = set(requested_scopes)
|
||||||
|
allowed_scopes_set = set(ALLOWED_SCOPES)
|
||||||
|
|
||||||
|
if not requested_scopes_set.issubset(allowed_scopes_set):
|
||||||
|
invalid_scopes = requested_scopes_set - allowed_scopes_set
|
||||||
|
raise ValueError(f"Invalid scopes: {', '.join(invalid_scopes)}")
|
||||||
|
|
||||||
|
# Check if requested scopes are in client's registered scopes
|
||||||
|
client_scopes = (
|
||||||
|
getattr(client, "scope", "").split() if hasattr(client, "scope") else []
|
||||||
|
)
|
||||||
|
client_scopes_set = set(client_scopes)
|
||||||
|
|
||||||
|
if client_scopes and not requested_scopes_set.issubset(client_scopes_set):
|
||||||
|
invalid_scopes = requested_scopes_set - client_scopes_set
|
||||||
|
logger.error(
|
||||||
|
f"❌ Client was not registered with scope(s): {invalid_scopes}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Client was not registered with scope {', '.join(invalid_scopes)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Store the authorization parameters in database
|
# Store the authorization parameters in database
|
||||||
with make_session() as session:
|
with make_session() as session:
|
||||||
oauth_state = OAuthState(
|
oauth_state = OAuthState(
|
||||||
@ -175,7 +209,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
).lower()
|
).lower()
|
||||||
== "true",
|
== "true",
|
||||||
code_challenge=params.code_challenge or "",
|
code_challenge=params.code_challenge or "",
|
||||||
scopes=["read", "write"], # Default scopes
|
scopes=requested_scopes,
|
||||||
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry
|
||||||
)
|
)
|
||||||
session.add(oauth_state)
|
session.add(oauth_state)
|
||||||
@ -240,6 +274,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
if not auth_code:
|
if not auth_code:
|
||||||
logger.error(f"Invalid authorization code: {authorization_code}")
|
logger.error(f"Invalid authorization code: {authorization_code}")
|
||||||
raise ValueError("Invalid authorization code")
|
raise ValueError("Invalid authorization code")
|
||||||
|
|
||||||
return AuthorizationCode(**auth_code.serialize(code=True))
|
return AuthorizationCode(**auth_code.serialize(code=True))
|
||||||
|
|
||||||
async def exchange_authorization_code(
|
async def exchange_authorization_code(
|
||||||
@ -272,7 +307,10 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
|
|
||||||
# Query for active (non-expired) session
|
# Query for active (non-expired) session
|
||||||
user_session = session.query(UserSession).get(token)
|
user_session = session.query(UserSession).get(token)
|
||||||
if not user_session or user_session.expires_at < now:
|
if not user_session:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if user_session.expires_at < now:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return AccessToken(
|
return AccessToken(
|
||||||
@ -390,7 +428,7 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
"""Return metadata about the protected resource."""
|
"""Return metadata about the protected resource."""
|
||||||
return {
|
return {
|
||||||
"resource_server": settings.SERVER_URL,
|
"resource_server": settings.SERVER_URL,
|
||||||
"scopes_supported": ["read", "write"],
|
"scopes_supported": ALLOWED_SCOPES,
|
||||||
"bearer_methods_supported": ["header"],
|
"bearer_methods_supported": ["header"],
|
||||||
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
"resource_documentation": f"{settings.SERVER_URL}/docs",
|
||||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
@ -399,12 +437,12 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider):
|
|||||||
"protected_resources": [
|
"protected_resources": [
|
||||||
{
|
{
|
||||||
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
"resource_uri": f"{settings.SERVER_URL}/mcp",
|
||||||
"scopes": ["read", "write"],
|
"scopes": RW_SCOPES,
|
||||||
"http_methods": ["POST", "GET"],
|
"http_methods": ["POST", "GET"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
"resource_uri": f"{settings.SERVER_URL}/mcp/",
|
||||||
"scopes": ["read", "write"],
|
"scopes": RW_SCOPES,
|
||||||
"http_methods": ["POST", "GET"],
|
"http_methods": ["POST", "GET"],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user