diff --git a/db/migrations/versions/20250606_123611_oauth_codes.py b/db/migrations/versions/20250606_123611_oauth_codes.py new file mode 100644 index 0000000..b2a91df --- /dev/null +++ b/db/migrations/versions/20250606_123611_oauth_codes.py @@ -0,0 +1,86 @@ +"""oauth codes + +Revision ID: 66771d293b27 +Revises: 58439dd3088b +Create Date: 2025-06-06 12:36:11.737507 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "66771d293b27" +down_revision: Union[str, None] = "58439dd3088b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "oauth_client", + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("client_secret", sa.String(), nullable=True), + sa.Column("client_id_issued_at", sa.Numeric(), nullable=False), + sa.Column("client_secret_expires_at", sa.Numeric(), nullable=True), + sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=False), + sa.Column("token_endpoint_auth_method", sa.String(), nullable=False), + sa.Column("grant_types", sa.ARRAY(sa.String()), nullable=False), + sa.Column("response_types", sa.ARRAY(sa.String()), nullable=False), + sa.Column("scope", sa.String(), nullable=False), + sa.Column("client_name", sa.String(), nullable=False), + sa.Column("client_uri", sa.String(), nullable=True), + sa.Column("logo_uri", sa.String(), nullable=True), + sa.Column("contacts", sa.ARRAY(sa.String()), nullable=True), + sa.Column("tos_uri", sa.String(), nullable=True), + sa.Column("policy_uri", sa.String(), nullable=True), + sa.Column("jwks_uri", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("client_id"), + ) + op.create_table( + "oauth_states", + sa.Column("state", sa.String(), nullable=False), + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("code", sa.String(), nullable=True), + sa.Column("redirect_uri", sa.String(), nullable=False), + sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False), + sa.Column("code_challenge", sa.String(), nullable=True), + sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True + ), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.Column("stale", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_client.client_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("state"), + ) + op.add_column( + "user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True) + ) + op.create_foreign_key( + "fk_user_sessions_oauth_state_id", + "user_sessions", + "oauth_states", + ["oauth_state_id"], + ["state"], + ) + + +def downgrade() -> None: + op.drop_constraint( + "fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey" + ) + op.drop_column("user_sessions", "oauth_state_id") + op.drop_table("oauth_states") + op.drop_table("oauth_client") diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py new file mode 100644 index 0000000..755c0c2 --- /dev/null +++ b/src/memory/api/MCP/base.py @@ -0,0 +1,124 @@ +import logging +import os +from typing import cast + +from mcp.server.auth.handlers.authorize import AuthorizationRequest +from mcp.server.auth.handlers.token import ( + AuthorizationCodeRequest, + RefreshTokenRequest, + TokenRequest, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp import FastMCP +from mcp.shared.auth import OAuthClientMetadata +from memory.common.db.models.users import User +from pydantic import AnyHttpUrl +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse +from starlette.templating import Jinja2Templates + +from memory.api.MCP.oauth_provider import SimpleOAuthProvider +from memory.common import settings +from memory.common.db.connection import make_session +from memory.common.db.models import OAuthState + +logger = logging.getLogger(__name__) + + +def validate_metadata(klass: type): + orig_validate = klass.model_validate + + def validate(data: dict): + data = dict(data) + if "redirect_uris" in data: + data["redirect_uris"] = [ + str(uri).replace("cursor://", "http://") + for uri in data["redirect_uris"] + ] + if "redirect_uri" in data: + data["redirect_uri"] = str(data["redirect_uri"]).replace( + "cursor://", "http://" + ) + + return orig_validate(data) + + klass.model_validate = validate + + +validate_metadata(OAuthClientMetadata) +validate_metadata(AuthorizationRequest) +validate_metadata(AuthorizationCodeRequest) +validate_metadata(RefreshTokenRequest) +validate_metadata(TokenRequest) + + +# Setup templates +template_dir = os.path.join(os.path.dirname(__file__), "templates") +templates = Jinja2Templates(directory=template_dir) + + +oauth_provider = SimpleOAuthProvider() +auth_settings = AuthSettings( + issuer_url=cast(AnyHttpUrl, settings.SERVER_URL), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["read", "write"], + default_scopes=["read"], + ), + required_scopes=["read", "write"], +) + +mcp = FastMCP( + "memory", + stateless_http=True, + auth_server_provider=oauth_provider, + auth=auth_settings, +) + + +@mcp.custom_route("/.well-known/oauth-protected-resource", methods=["GET"]) +async def oauth_protected_resource(request: Request): + """OAuth 2.0 Protected Resource Metadata.""" + logger.info("Protected resource metadata requested") + return JSONResponse(oauth_provider.get_protected_resource_metadata()) + + +def login_form(request: Request, form_data: dict, error: str | None = None): + return templates.TemplateResponse( + "login.html", + {"request": request, "form_data": form_data, "error": error}, + ) + + +@mcp.custom_route("/oauth/login", methods=["GET"]) +async def login_page(request: Request): + """Display the login page.""" + form_data = dict(request.query_params) + + state = form_data.get("state") + with make_session() as session: + oauth_state = session.query(OAuthState).get(state) + if not oauth_state: + logger.error(f"State {state} not found in database") + raise ValueError("Invalid state parameter") + + return login_form(request, form_data, None) + + +@mcp.custom_route("/oauth/login", methods=["POST"]) +async def handle_login(request: Request): + """Handle login form submission.""" + form = await request.form() + oauth_params = { + key: value for key, value in form.items() if key not in ["email", "password"] + } + with make_session() as session: + user = session.query(User).filter(User.email == form.get("email")).first() + if not user or not user.is_valid_password(str(form.get("password", ""))): + logger.warning("Login failed - invalid credentials") + return login_form(request, oauth_params, "Invalid email or password") + + redirect_url = await oauth_provider.complete_authorization(oauth_params, user) + if redirect_url.startswith("http://anysphere.cursor-retrieval"): + redirect_url = redirect_url.replace("http://", "cursor://") + return RedirectResponse(url=redirect_url, status_code=302) diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py new file mode 100644 index 0000000..e5b52d1 --- /dev/null +++ b/src/memory/api/MCP/oauth_provider.py @@ -0,0 +1,246 @@ +import secrets +import time +from typing import Optional, Any, cast +from urllib.parse import urlencode +import logging +from datetime import datetime, timezone + +from memory.common.db.models.users import ( + User, + UserSession, + OAuthClientInformation, + OAuthState, +) +from memory.common.db.connection import make_session +from memory.common import settings +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationParams, + AuthorizationCode, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class SimpleOAuthProvider(OAuthAuthorizationServerProvider): + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + with make_session() as session: + client = session.query(OAuthClientInformation).get(client_id) + return client and OAuthClientInformationFull(**client.serialize()) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + with make_session() as session: + client = session.query(OAuthClientInformation).get(client_info.client_id) + if not client: + client = OAuthClientInformation(client_id=client_info.client_id) + + for key, value in client_info.model_dump().items(): + if key == "redirect_uris": + value = [str(uri) for uri in value] + elif value and key in [ + "client_uri", + "logo_uri", + "tos_uri", + "policy_uri", + "jwks_uri", + ]: + value = str(value) + setattr(client, key, value) + session.add(client) + session.commit() + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Redirect to login page for user authentication.""" + redirect_uri_str = str(params.redirect_uri) + registered_uris = [str(uri) for uri in getattr(client, "redirect_uris", [])] + if redirect_uri_str not in registered_uris: + logger.error( + f"Redirect URI {redirect_uri_str} not in registered URIs: {registered_uris}" + ) + raise ValueError(f"Invalid redirect_uri: {redirect_uri_str}") + + # Store the authorization parameters in database + with make_session() as session: + oauth_state = OAuthState( + state=params.state or secrets.token_hex(16), + client_id=client.client_id, + redirect_uri=str(params.redirect_uri), + redirect_uri_provided_explicitly=str( + params.redirect_uri_provided_explicitly + ).lower() + == "true", + code_challenge=params.code_challenge or "", + scopes=["read", "write"], # Default scopes + expires_at=datetime.fromtimestamp(time.time() + 600), # 10 min expiry + ) + session.add(oauth_state) + session.commit() + + return f"{settings.SERVER_URL}/oauth/login?" + urlencode( + { + "state": oauth_state.state, + "client_id": client.client_id, + "redirect_uri": oauth_state.redirect_uri, + "redirect_uri_provided_explicitly": oauth_state.redirect_uri_provided_explicitly, + "code_challenge": cast(str, oauth_state.code_challenge), + } + ) + + async def complete_authorization(self, oauth_params: dict, user: User) -> str: + """Complete authorization after successful login.""" + if not (state := oauth_params.get("state")): + logger.error("No state parameter provided") + raise ValueError("Missing state parameter") + + with make_session() as session: + # Load OAuth state from database + oauth_state = session.query(OAuthState).get(state) + if not oauth_state: + logger.error(f"State {state} not found in database") + raise ValueError("Invalid state parameter") + + # Check if state has expired + now = datetime.fromtimestamp(time.time()) + if oauth_state.expires_at < now: + logger.error(f"State {state} has expired") + oauth_state.stale = True + session.commit() + raise ValueError("State has expired") + + oauth_state.code = f"code_{secrets.token_hex(16)}" + oauth_state.stale = False + oauth_state.user_id = user.id + + session.add(oauth_state) + session.commit() + + return construct_redirect_uri( + oauth_state.redirect_uri, code=oauth_state.code, state=state + ) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> Optional[AuthorizationCode]: + """Load an authorization code.""" + with make_session() as session: + auth_code = ( + session.query(OAuthState) + .filter(OAuthState.code == authorization_code) + .first() + ) + if not auth_code: + logger.error(f"Invalid authorization code: {authorization_code}") + raise ValueError("Invalid authorization code") + return AuthorizationCode(**auth_code.serialize(code=True)) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + with make_session() as session: + auth_code = ( + session.query(OAuthState) + .filter(OAuthState.code == authorization_code.code) + .first() + ) + + if not auth_code: + logger.error(f"Invalid authorization code: {authorization_code.code}") + raise ValueError("Invalid authorization code") + + # Get the user associated with this auth code + if not auth_code.user: + logger.error(f"No user found for auth code: {authorization_code.code}") + raise ValueError("Invalid authorization code") + + # Create a UserSession to serve as access token + expires_at = datetime.fromtimestamp(time.time() + 3600) + + auth_code.session = UserSession( + user_id=auth_code.user_id, + oauth_state_id=auth_code.state, + expires_at=expires_at, + ) + auth_code.stale = True # type: ignore + session.commit() + access_token = str(auth_code.session.id) + + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> Optional[AccessToken]: + """Load and validate an access token.""" + with make_session() as session: + now = datetime.now(timezone.utc).replace( + tzinfo=None + ) # Make naive for DB comparison + + # Query for active (non-expired) session + user_session = session.query(UserSession).get(token) + if not user_session or user_session.expires_at < now: + return None + + return AccessToken( + token=token, + client_id=user_session.oauth_state.client_id, + scopes=user_session.oauth_state.scopes, + expires_at=int(user_session.expires_at.timestamp()), + ) + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> Optional[RefreshToken]: + """Load a refresh token - not supported in this simple implementation.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token - not supported in this simple implementation.""" + raise NotImplementedError("Refresh tokens not supported") + + async def revoke_token( + self, token: str, token_type_hint: Optional[str] = None + ) -> None: + """Revoke a token.""" + with make_session() as session: + user_session = session.query(UserSession).get(token) + if user_session: + session.delete(user_session) + session.commit() + + def get_protected_resource_metadata(self) -> dict[str, Any]: + """Return metadata about the protected resource.""" + return { + "resource_server": settings.SERVER_URL, + "scopes_supported": ["read", "write"], + "bearer_methods_supported": ["header"], + "resource_documentation": f"{settings.SERVER_URL}/docs", + "protected_resources": [ + { + "resource_uri": f"{settings.SERVER_URL}/mcp", + "scopes": ["read", "write"], + "http_methods": ["POST", "GET"], + }, + { + "resource_uri": f"{settings.SERVER_URL}/mcp/", + "scopes": ["read", "write"], + "http_methods": ["POST", "GET"], + }, + ], + } diff --git a/src/memory/api/MCP/templates/login.html b/src/memory/api/MCP/templates/login.html new file mode 100644 index 0000000..5ec2011 --- /dev/null +++ b/src/memory/api/MCP/templates/login.html @@ -0,0 +1,159 @@ + + + +
+ + +Access your Memory knowledge base
+