From 986d5b9957c4884573d588def37ff553c6157a96 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Fri, 6 Jun 2025 17:07:25 +0200 Subject: [PATCH] oauth refresh + revoke --- ...odes.py => 20250606_165333_oauth_codes.py} | 51 +++- docker-compose.yaml | 1 + src/memory/api/MCP/base.py | 4 +- src/memory/api/MCP/oauth_provider.py | 234 +++++++++++++++--- src/memory/common/db/models/__init__.py | 2 + src/memory/common/db/models/users.py | 58 ++++- 6 files changed, 294 insertions(+), 56 deletions(-) rename db/migrations/versions/{20250606_123611_oauth_codes.py => 20250606_165333_oauth_codes.py} (65%) diff --git a/db/migrations/versions/20250606_123611_oauth_codes.py b/db/migrations/versions/20250606_165333_oauth_codes.py similarity index 65% rename from db/migrations/versions/20250606_123611_oauth_codes.py rename to db/migrations/versions/20250606_165333_oauth_codes.py index b2a91df..d967f5b 100644 --- a/db/migrations/versions/20250606_123611_oauth_codes.py +++ b/db/migrations/versions/20250606_165333_oauth_codes.py @@ -1,8 +1,8 @@ """oauth codes -Revision ID: 66771d293b27 +Revision ID: 1d6bc8015ea9 Revises: 58439dd3088b -Create Date: 2025-06-06 12:36:11.737507 +Create Date: 2025-06-06 16:53:33.044558 """ @@ -13,7 +13,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = "66771d293b27" +revision: str = "1d6bc8015ea9" down_revision: Union[str, None] = "58439dd3088b" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -43,18 +43,19 @@ def upgrade() -> None: op.create_table( "oauth_states", sa.Column("state", sa.String(), nullable=False), - sa.Column("client_id", sa.String(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("code", sa.String(), nullable=True), sa.Column("redirect_uri", sa.String(), nullable=False), sa.Column("redirect_uri_provided_explicitly", sa.Boolean(), nullable=False), sa.Column("code_challenge", sa.String(), nullable=True), + sa.Column("stale", sa.Boolean(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False), sa.Column( "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True ), sa.Column("expires_at", sa.DateTime(), nullable=False), - sa.Column("stale", sa.Boolean(), nullable=False), sa.ForeignKeyConstraint( ["client_id"], ["oauth_client.client_id"], @@ -63,24 +64,52 @@ def upgrade() -> None: ["user_id"], ["users.id"], ), - sa.PrimaryKeyConstraint("state"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "oauth_refresh_tokens", + sa.Column("token", sa.String(), nullable=False), + sa.Column("revoked", sa.Boolean(), nullable=False), + sa.Column("access_token_session_id", sa.String(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("scopes", sa.ARRAY(sa.String()), nullable=False), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True + ), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["access_token_session_id"], + ["user_sessions.id"], + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_client.client_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), ) op.add_column( - "user_sessions", sa.Column("oauth_state_id", sa.String(), nullable=True) + "user_sessions", sa.Column("oauth_state_id", sa.Integer(), nullable=True) ) op.create_foreign_key( - "fk_user_sessions_oauth_state_id", + "user_sessions_oauth_state_id_fkey", "user_sessions", "oauth_states", ["oauth_state_id"], - ["state"], + ["id"], ) def downgrade() -> None: op.drop_constraint( - "fk_user_sessions_oauth_state_id", "user_sessions", type_="foreignkey" + "user_sessions_oauth_state_id_fkey", "user_sessions", type_="foreignkey" ) op.drop_column("user_sessions", "oauth_state_id") + op.drop_table("oauth_refresh_tokens") op.drop_table("oauth_states") op.drop_table("oauth_client") diff --git a/docker-compose.yaml b/docker-compose.yaml index 9bb5fcb..3fef575 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -147,6 +147,7 @@ services: <<: *env POSTGRES_PASSWORD_FILE: /run/secrets/postgres_password QDRANT_URL: http://qdrant:6333 + SERVER_URL: "${SERVER_URL:-http://localhost:8000}" secrets: [postgres_password] volumes: - ./memory_files:/app/memory_files:rw diff --git a/src/memory/api/MCP/base.py b/src/memory/api/MCP/base.py index 755c0c2..c3b7369 100644 --- a/src/memory/api/MCP/base.py +++ b/src/memory/api/MCP/base.py @@ -97,7 +97,9 @@ async def login_page(request: Request): state = form_data.get("state") with make_session() as session: - oauth_state = session.query(OAuthState).get(state) + oauth_state = ( + session.query(OAuthState).filter(OAuthState.state == state).first() + ) if not oauth_state: logger.error(f"State {state} not found in database") raise ValueError("Invalid state parameter") diff --git a/src/memory/api/MCP/oauth_provider.py b/src/memory/api/MCP/oauth_provider.py index e5b52d1..2405890 100644 --- a/src/memory/api/MCP/oauth_provider.py +++ b/src/memory/api/MCP/oauth_provider.py @@ -10,8 +10,10 @@ from memory.common.db.models.users import ( UserSession, OAuthClientInformation, OAuthState, + OAuthRefreshToken, + OAuthToken as TokenBase, ) -from memory.common.db.connection import make_session +from memory.common.db.connection import make_session, scoped_session from memory.common import settings from mcp.server.auth.provider import ( AccessToken, @@ -26,6 +28,101 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken logger = logging.getLogger(__name__) +# Token configuration constants +ACCESS_TOKEN_LIFETIME = 3600 * 30 * 24 # 30 days +REFRESH_TOKEN_LIFETIME = 30 * 24 * 3600 # 30 days + + +def create_expiration(lifetime_seconds: int) -> datetime: + """Create expiration datetime from lifetime in seconds.""" + return datetime.fromtimestamp(time.time() + lifetime_seconds) + + +def generate_refresh_token() -> str: + """Generate a new refresh token.""" + return f"rt_{secrets.token_hex(32)}" + + +def create_access_token_session( + user_id: int, oauth_state_id: str | None = None +) -> UserSession: + """Create a new access token session.""" + return UserSession( + user_id=user_id, + oauth_state_id=oauth_state_id, + expires_at=create_expiration(ACCESS_TOKEN_LIFETIME), + ) + + +def create_refresh_token_record( + client_id: str, + user_id: int, + scopes: list[str], + access_token_session_id: Optional[str] = None, +) -> OAuthRefreshToken: + """Create a new refresh token record.""" + return OAuthRefreshToken( + token=generate_refresh_token(), + client_id=client_id, + user_id=user_id, + scopes=scopes, + expires_at=create_expiration(REFRESH_TOKEN_LIFETIME), + access_token_session_id=access_token_session_id, + ) + + +def validate_refresh_token(db_refresh_token: OAuthRefreshToken) -> None: + """Validate a refresh token, raising ValueError if invalid.""" + now = datetime.now() + if db_refresh_token.expires_at < now: # type: ignore + logger.error(f"Refresh token expired: {db_refresh_token.token[:20]}...") + db_refresh_token.revoked = True # type: ignore + raise ValueError("Refresh token expired") + + +def create_oauth_token( + access_token: str, scopes: list[str], refresh_token: Optional[str] = None +) -> OAuthToken: + """Create an OAuth token response.""" + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=ACCESS_TOKEN_LIFETIME, + refresh_token=refresh_token, + scope=" ".join(scopes), + ) + + +def make_token( + db: scoped_session, + auth_state: TokenBase, + scopes: list[str], +) -> OAuthToken: + new_session = UserSession( + user_id=auth_state.user_id, + oauth_state_id=auth_state.id, + expires_at=create_expiration(ACCESS_TOKEN_LIFETIME), + ) + + # Create refresh token + refresh_token = create_refresh_token_record( + cast(str, auth_state.client_id), + cast(int, auth_state.user_id), + scopes, + cast(str, new_session.id), + ) + + db.add(new_session) + db.add(refresh_token) + db.commit() + + return create_oauth_token( + str(new_session.id), + scopes, + cast(str, refresh_token.token), + ) + + class SimpleOAuthProvider(OAuthAuthorizationServerProvider): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """Get OAuth client information.""" @@ -102,7 +199,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): with make_session() as session: # Load OAuth state from database - oauth_state = session.query(OAuthState).get(state) + oauth_state = ( + session.query(OAuthState).filter(OAuthState.state == state).first() + ) if not oauth_state: logger.error(f"State {state} not found in database") raise ValueError("Invalid state parameter") @@ -111,19 +210,21 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): now = datetime.fromtimestamp(time.time()) if oauth_state.expires_at < now: logger.error(f"State {state} has expired") - oauth_state.stale = True + oauth_state.stale = True # type: ignore session.commit() raise ValueError("State has expired") - oauth_state.code = f"code_{secrets.token_hex(16)}" - oauth_state.stale = False + oauth_state.code = f"code_{secrets.token_hex(16)}" # type: ignore + oauth_state.stale = False # type: ignore oauth_state.user_id = user.id session.add(oauth_state) session.commit() return construct_redirect_uri( - oauth_state.redirect_uri, code=oauth_state.code, state=state + cast(str, oauth_state.redirect_uri), + code=cast(str, oauth_state.code), + state=state, ) async def load_authorization_code( @@ -156,29 +257,11 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): logger.error(f"Invalid authorization code: {authorization_code.code}") raise ValueError("Invalid authorization code") - # Get the user associated with this auth code if not auth_code.user: logger.error(f"No user found for auth code: {authorization_code.code}") raise ValueError("Invalid authorization code") - # Create a UserSession to serve as access token - expires_at = datetime.fromtimestamp(time.time() + 3600) - - auth_code.session = UserSession( - user_id=auth_code.user_id, - oauth_state_id=auth_code.state, - expires_at=expires_at, - ) - auth_code.stale = True # type: ignore - session.commit() - access_token = str(auth_code.session.id) - - return OAuthToken( - access_token=access_token, - token_type="bearer", - expires_in=3600, - scope=" ".join(authorization_code.scopes), - ) + return make_token(session, auth_code, authorization_code.scopes) async def load_access_token(self, token: str) -> Optional[AccessToken]: """Load and validate an access token.""" @@ -192,6 +275,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): if not user_session or user_session.expires_at < now: return None + logger.info( + f"Loading access token: {token}, state: {user_session.oauth_state}" + ) return AccessToken( token=token, client_id=user_session.oauth_state.client_id, @@ -202,8 +288,34 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): async def load_refresh_token( self, client: OAuthClientInformationFull, refresh_token: str ) -> Optional[RefreshToken]: - """Load a refresh token - not supported in this simple implementation.""" - return None + """Load and validate a refresh token.""" + with make_session() as session: + now = datetime.now() + + # Query for the refresh token + db_refresh_token = ( + session.query(OAuthRefreshToken) + .filter( + OAuthRefreshToken.token == refresh_token, + OAuthRefreshToken.client_id == client.client_id, + OAuthRefreshToken.revoked == False, # noqa: E712 + OAuthRefreshToken.expires_at > now, + ) + .first() + ) + + if not db_refresh_token: + logger.error( + f"Invalid or expired refresh token: {refresh_token[:20]}..." + ) + return None + + return RefreshToken( + token=refresh_token, + client_id=client.client_id, + scopes=cast(list[str], db_refresh_token.scopes), + expires_at=int(db_refresh_token.expires_at.timestamp()), + ) async def exchange_refresh_token( self, @@ -211,18 +323,71 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): refresh_token: RefreshToken, scopes: list[str], ) -> OAuthToken: - """Exchange refresh token - not supported in this simple implementation.""" - raise NotImplementedError("Refresh tokens not supported") + """Exchange refresh token for new access token.""" + with make_session() as session: + # Load the refresh token from database + db_refresh_token = ( + session.query(OAuthRefreshToken) + .filter( + OAuthRefreshToken.token == refresh_token.token, + OAuthRefreshToken.client_id == client.client_id, + OAuthRefreshToken.revoked == False, # noqa: E712 + ) + .first() + ) + + if not db_refresh_token: + logger.error(f"Refresh token not found: {refresh_token.token[:20]}...") + raise ValueError("Invalid refresh token") + + # Validate refresh token + validate_refresh_token(db_refresh_token) + + # Validate requested scopes are subset of original scopes + original_scopes = set(cast(list[str], db_refresh_token.scopes)) + requested_scopes = set(scopes) if scopes else original_scopes + + if not requested_scopes.issubset(original_scopes): + logger.error( + f"Requested scopes {requested_scopes} exceed original scopes {original_scopes}" + ) + raise ValueError("Requested scopes exceed original authorization") + + return make_token(session, db_refresh_token, scopes) async def revoke_token( self, token: str, token_type_hint: Optional[str] = None ) -> None: - """Revoke a token.""" + """Revoke a token (access token or refresh token).""" with make_session() as session: - user_session = session.query(UserSession).get(token) - if user_session: - session.delete(user_session) + revoked = False + + # Try to revoke as access token (UserSession) + if not token_type_hint or token_type_hint == "access_token": + user_session = session.query(UserSession).get(token) + if user_session: + session.delete(user_session) + revoked = True + logger.info(f"Revoked access token: {token[:20]}...") + + # Try to revoke as refresh token + if not revoked and ( + not token_type_hint or token_type_hint == "refresh_token" + ): + refresh_token = ( + session.query(OAuthRefreshToken) + .filter(OAuthRefreshToken.token == token) + .first() + ) + if refresh_token: + refresh_token.revoked = True # type: ignore + revoked = True + logger.info(f"Revoked refresh token: {token[:20]}...") + + if revoked: session.commit() + else: + logger.warning(f"Token not found for revocation: {token[:20]}...") def get_protected_resource_metadata(self) -> dict[str, Any]: """Return metadata about the protected resource.""" @@ -231,6 +396,9 @@ class SimpleOAuthProvider(OAuthAuthorizationServerProvider): "scopes_supported": ["read", "write"], "bearer_methods_supported": ["header"], "resource_documentation": f"{settings.SERVER_URL}/docs", + "grant_types_supported": ["authorization_code", "refresh_token"], + "token_endpoint_auth_methods_supported": ["none", "client_secret_basic"], + "refresh_token_rotation_enabled": True, "protected_resources": [ { "resource_uri": f"{settings.SERVER_URL}/mcp", diff --git a/src/memory/common/db/models/__init__.py b/src/memory/common/db/models/__init__.py index 352454d..8888a8f 100644 --- a/src/memory/common/db/models/__init__.py +++ b/src/memory/common/db/models/__init__.py @@ -37,6 +37,7 @@ from memory.common.db.models.users import ( UserSession, OAuthClientInformation, OAuthState, + OAuthRefreshToken, ) __all__ = [ @@ -73,4 +74,5 @@ __all__ = [ "UserSession", "OAuthClientInformation", "OAuthState", + "OAuthRefreshToken", ] diff --git a/src/memory/common/db/models/users.py b/src/memory/common/db/models/users.py index ef48be5..5c81d77 100644 --- a/src/memory/common/db/models/users.py +++ b/src/memory/common/db/models/users.py @@ -15,6 +15,7 @@ from sqlalchemy import ( ) from sqlalchemy.sql import func from sqlalchemy.orm import relationship +from datetime import datetime def hash_password(password: str) -> str: @@ -69,7 +70,7 @@ class UserSession(Base): __tablename__ = "user_sessions" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - oauth_state_id = Column(String, ForeignKey("oauth_states.state"), nullable=True) + oauth_state_id = Column(Integer, ForeignKey("oauth_states.id"), nullable=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime, server_default=func.now()) expires_at = Column(DateTime, nullable=False) @@ -125,29 +126,38 @@ class OAuthClientInformation(Base): } -class OAuthState(Base): - __tablename__ = "oauth_states" - - state = Column(String, primary_key=True) +class OAuthToken: + id = Column(Integer, primary_key=True) client_id = Column(String, ForeignKey("oauth_client.client_id"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + scopes = Column(ARRAY(String), nullable=False) + created_at = Column(DateTime, server_default=func.now()) + expires_at = Column(DateTime, nullable=False) + + def serialize(self) -> dict: + return { + "client_id": self.client_id, + "scopes": self.scopes, + "expires_at": self.expires_at.timestamp(), + } + + +class OAuthState(Base, OAuthToken): + __tablename__ = "oauth_states" + + state = Column(String, nullable=False) code = Column(String, nullable=True) redirect_uri = Column(String, nullable=False) redirect_uri_provided_explicitly = Column(Boolean, nullable=False) code_challenge = Column(String, nullable=True) - scopes = Column(ARRAY(String), nullable=False) - created_at = Column(DateTime, server_default=func.now()) - expires_at = Column(DateTime, nullable=False) stale = Column(Boolean, nullable=False, default=False) def serialize(self, code: bool = False) -> dict: data = { - "client_id": self.client_id, "redirect_uri": self.redirect_uri, "redirect_uri_provided_explicitly": self.redirect_uri_provided_explicitly, "code_challenge": self.code_challenge, - "scopes": self.scopes, - } + } | super().serialize() if code: data |= { "code": self.code, @@ -158,3 +168,29 @@ class OAuthState(Base): client = relationship("OAuthClientInformation", back_populates="sessions") session = relationship("UserSession", back_populates="oauth_state", uselist=False) user = relationship("User", back_populates="oauth_states") + + +class OAuthRefreshToken(Base, OAuthToken): + __tablename__ = "oauth_refresh_tokens" + + token = Column( + String, nullable=False, default=lambda: f"rt_{secrets.token_hex(32)}" + ) + revoked = Column(Boolean, nullable=False, default=False) + + # Optional: link to the access token session that was created with this refresh token + access_token_session_id = Column( + String, ForeignKey("user_sessions.id"), nullable=True + ) + + # Relationships + client = relationship("OAuthClientInformation") + user = relationship("User") + access_token_session = relationship("UserSession") + + def serialize(self) -> dict: + return { + "token": self.token, + "expires_at": self.expires_at.timestamp(), + "revoked": self.revoked, + } | super().serialize()